Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions LION/CTtools/ct_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy as np
import tomosipo as ts
import torch
import torchvision.transforms.functional as TF

# AItomotools imports
from LION.CTtools.ct_geometry import Geometry
Expand Down Expand Up @@ -54,7 +55,8 @@ def sinogram_add_noise(
proj,
I0=1000,
sigma=5,
cross_talk=0.05,
sigma_blur=0.3015,
ks_value=3,
flat_field=None,
dark_field=None,
enable_gradients=False,
Expand All @@ -63,17 +65,25 @@ def sinogram_add_noise(
Wraper for _sinogram_add_noise to support gradients
"""
if enable_gradients:
sino = _sinogram_add_noise(proj, I0, sigma, cross_talk, flat_field, dark_field)
sino = _sinogram_add_noise(
proj, I0, sigma, sigma_blur, ks_value, flat_field, dark_field
)
else:
with torch.no_grad():
sino = _sinogram_add_noise(
proj.detach(), I0, sigma, cross_talk, flat_field, dark_field
proj.detach(), I0, sigma, sigma_blur, ks_value, flat_field, dark_field
)
return sino


def _sinogram_add_noise(
proj, I0=1000, sigma=5, cross_talk=0.05, flat_field=None, dark_field=None
proj,
I0=1000,
sigma=5,
sigma_blur=0.3015,
ks_value=3,
flat_field=None,
dark_field=None,
):
"""
Adds realistic noise to sinograms.
Expand Down Expand Up @@ -101,7 +111,8 @@ def _sinogram_add_noise(
max_val = torch.amax(
proj
) # alternatively the highest power of 2 close to this value, but lets leave it as is.

if max_val <= 0:
max_val = 1.0
Im = I0 * torch.exp(-proj / max_val)

# Uncorrect the flat fields
Expand All @@ -111,17 +122,11 @@ def _sinogram_add_noise(
Im = torch.poisson(Im)

# Detector cross talk
ks = int(sigma_blur * ks_value) * 2 + 1
if ks < 3:
ks = 3
Im = TF.gaussian_blur(Im, kernel_size=[ks, ks], sigma=[sigma_blur, sigma_blur])

kernel = torch.tensor(
[[0.0, 0.0, 0.0], [cross_talk, 1, cross_talk], [0.0, 0.0, 0.0]]
).view(1, 1, 3, 3).repeat(1, 1, 1, 1) / (1 + 2 * cross_talk)

conv = torch.nn.Conv2d(1, 1, 3, bias=False, padding="same")
with torch.no_grad():
conv.weight = torch.nn.Parameter(kernel)
conv = conv.to(dev)

Im = conv(Im.unsqueeze(0))[0]
# Electronic noise:
Im = Im + sigma * torch.randn(Im.shape, device=dev)

Expand Down
26 changes: 13 additions & 13 deletions LION/experiments/ct_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def __init__(self, experiment_params=None, dataset="LIDC-IDRI", datafolder=None)
self.geometry = self.experiment_params.geometry
self.dataset = dataset
if hasattr(self.param, "noise_params"):
self.sino_fun = lambda sino, I0=self.param.noise_params.I0, sigma=self.param.noise_params.sigma, cross_talk=self.param.noise_params.cross_talk: ct.sinogram_add_noise(
sino, I0=I0, sigma=sigma, cross_talk=cross_talk
self.sino_fun = lambda sino, I0=self.param.noise_params.I0, sigma=self.param.noise_params.sigma, sigma_blur=self.param.noise_params.sigma_blur: ct.sinogram_add_noise(
sino, I0=I0, sigma=sigma, sigma_blur=sigma_blur
)

@staticmethod
Expand All @@ -62,8 +62,8 @@ def __get_dataset(self, mode):
self.param.data_loader_params.noise_params.I0 = (
self.param.noise_params.I0
)
self.param.data_loader_params.noise_params.cross_talk = (
self.param.noise_params.cross_talk
self.param.data_loader_params.noise_params.sigma_blur = (
self.param.noise_params.sigma_blur
)
self.param.data_loader_params.add_noise = True
dataloader = deteCT(
Expand Down Expand Up @@ -112,7 +112,7 @@ def default_parameters(dataset="LIDC-IDRI"):
param.noise_params = LIONParameter()
param.noise_params.I0 = 1000
param.noise_params.sigma = 5
param.noise_params.cross_talk = 0.05
param.noise_params.sigma_blur = 0.3015
param.data_loader_params = Experiment.get_dataset_parameters(
dataset, geometry=param.geometry
)
Expand All @@ -135,7 +135,7 @@ def default_parameters(dataset="LIDC-IDRI"):
param.noise_params = LIONParameter()
param.noise_params.I0 = 3500
param.noise_params.sigma = 5
param.noise_params.cross_talk = 0.05
param.noise_params.sigma_blur = 0.3015
param.data_loader_params = Experiment.get_dataset_parameters(
dataset, geometry=param.geometry
)
Expand All @@ -158,7 +158,7 @@ def default_parameters(dataset="LIDC-IDRI"):
param.noise_params = LIONParameter()
param.noise_params.I0 = 10000
param.noise_params.sigma = 5
param.noise_params.cross_talk = 0.05
param.noise_params.sigma_blur = 0.3015
param.data_loader_params = Experiment.get_dataset_parameters(
dataset, geometry=param.geometry
)
Expand All @@ -180,7 +180,7 @@ def default_parameters(dataset="LIDC-IDRI"):
param.noise_params = LIONParameter()
param.noise_params.I0 = 3500
param.noise_params.sigma = 5
param.noise_params.cross_talk = 0.05
param.noise_params.sigma_blur = 0.3015

if dataset == "LIDC-IDRI":
# Parameters for the LIDC-IDRI dataset
Expand All @@ -207,7 +207,7 @@ def default_parameters(dataset="LIDC-IDRI"):
param.noise_params = LIONParameter()
param.noise_params.I0 = 1000
param.noise_params.sigma = 5
param.noise_params.cross_talk = 0.05
param.noise_params.sigma_blur = 0.3015

if dataset == "LIDC-IDRI":
# Parameters for the LIDC-IDRI dataset
Expand Down Expand Up @@ -235,7 +235,7 @@ def default_parameters(dataset="LIDC-IDRI"):
param.noise_params = LIONParameter()
param.noise_params.I0 = 10000
param.noise_params.sigma = 5
param.noise_params.cross_talk = 0.05
param.noise_params.sigma_blur = 0.3015

param.data_loader_params = Experiment.get_dataset_parameters(
dataset, geometry=param.geometry
Expand All @@ -258,7 +258,7 @@ def default_parameters(dataset="LIDC-IDRI"):
param.noise_params = LIONParameter()
param.noise_params.I0 = 3500
param.noise_params.sigma = 5
param.noise_params.cross_talk = 0.05
param.noise_params.sigma_blur = 0.3015

if dataset == "LIDC-IDRI":
# Parameters for the LIDC-IDRI dataset
Expand All @@ -285,7 +285,7 @@ def default_parameters(dataset="LIDC-IDRI"):
param.noise_params = LIONParameter()
param.noise_params.I0 = 1000
param.noise_params.sigma = 5
param.noise_params.cross_talk = 0.05
param.noise_params.sigma_blur = 0.3015

if dataset == "LIDC-IDRI":
# Parameters for the LIDC-IDRI dataset
Expand Down Expand Up @@ -313,7 +313,7 @@ def default_parameters(dataset="LIDC-IDRI"):
param.noise_params = LIONParameter()
param.noise_params.I0 = 10000
param.noise_params.sigma = 5
param.noise_params.cross_talk = 0.05
param.noise_params.sigma_blur = 0.3015

param.data_loader_params = Experiment.get_dataset_parameters(
dataset, geometry=param.geometry
Expand Down
43 changes: 28 additions & 15 deletions LION/losses/SUREpgImage.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,55 @@
import torch

class SUREpgLoss():

def __init__(self, zeta: float, sigma2: float, eps1: float = 1e-3, eps2: float = 1e-3, kappa: float = 1.0):
class SUREpgLoss:
def __init__(
self,
zeta: float,
sigma2: float,
eps1: float = 1e-3,
eps2: float = 1e-3,
kappa: float = 1.0,
):
self.zeta = zeta
self.sigma2 = sigma2
self.eps1 = eps1
self.eps2 = eps2
self.kappa = kappa

self.p = (1/2) * (1 + self.kappa / (self.kappa**2 + 4)**0.5)
self.p = (1 / 2) * (1 + self.kappa / (self.kappa**2 + 4) ** 0.5)
self.q = 1 - self.p
self.a = (self.q / self.p)**0.5
self.b = (self.p / self.q)**0.5
self.a = (self.q / self.p) ** 0.5
self.b = (self.p / self.q) ** 0.5

def __call__(self, model, y):
B=y.shape[0]
B = y.shape[0]
N_per_img = y.shape[1] * y.shape[2] * y.shape[3]

fy=model(y)
loss = ((fy - y) ** 2).sum(dim=(1,2,3)) - self.zeta * y.sum(dim=(1,2,3)) - self.sigma2 * N_per_img
fy = model(y)
loss = (
((fy - y) ** 2).sum(dim=(1, 2, 3))
- self.zeta * y.sum(dim=(1, 2, 3))
- self.sigma2 * N_per_img
)

#1st derivative MC
# 1st derivative MC
delta1 = torch.randn_like(y)
fy_perturbated = model(y + self.eps1 * delta1)

u = self.zeta * y + self.sigma2
mc1 = (delta1 * u * (fy_perturbated - fy)).sum(dim=(1,2,3))
mc1 = (delta1 * u * (fy_perturbated - fy)).sum(dim=(1, 2, 3))
loss += 2.0 * mc1 / self.eps1

#2nd derivative MC
# 2nd derivative MC
u_rand = torch.rand_like(y)
delta2 = torch.where(u_rand < self.p,-self.a * torch.ones_like(y),+self.b * torch.ones_like(y))
delta2 = torch.where(
u_rand < self.p, -self.a * torch.ones_like(y), +self.b * torch.ones_like(y)
)

fy_plus = model(y + self.eps2 * delta2)
fy_plus = model(y + self.eps2 * delta2)
fy_minus = model(y - self.eps2 * delta2)

mc2 = (delta2 * (fy_plus - 2*fy + fy_minus)).sum(dim=(1,2,3))
mc2 = (delta2 * (fy_plus - 2 * fy + fy_minus)).sum(dim=(1, 2, 3))
loss -= (2 * self.sigma2 * self.zeta / (self.eps2**2 * self.kappa)) * mc2

return loss.mean()/N_per_img
return loss.mean() / N_per_img
101 changes: 101 additions & 0 deletions LION/optimizers/Equivariance2InverseSolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from typing import Callable, Optional
import numpy as np
from LION.CTtools.ct_geometry import Geometry
from LION.classical_algorithms.fdk import fdk
from LION.models.LIONmodel import LIONmodel
import torch
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer
from LION.optimizers.LIONsolver import LIONsolver
from LION.utils.parameter import LIONParameter
import tomosipo as ts
import LION.CTtools.ct_utils as ct_utils
from tomosipo.torch_support import to_autograd
import random
import torchvision.transforms as TT


class Equivariance2InverseSolver(LIONsolver):
def __init__(
self,
model: LIONmodel,
optimizer: Optimizer,
loss_fn,
solver_params: Optional[LIONParameter] = None,
geometry: Geometry = None,
verbose: bool = True,
device: torch.device = None,
) -> None:
print(device)
super().__init__(
model,
optimizer,
loss_fn,
geometry,
verbose,
device,
solver_params=solver_params,
)

self.operator = ct_utils.make_operator(self.geometry)
self.model.geometry = self.geometry
self.model.operator = self.operator
self.projector = to_autograd(self.operator, num_extra_dims=1)
self.recon_fn = self.solver_params.recon_fn

@staticmethod
def default_parameters() -> LIONParameter:
params = LIONParameter()
params.recon_fn = fdk
params.I0 = 500
params.sigma = (50) ** (0.5)
params.sigma_blur = 0.8
return params

def mini_batch_step(self, sinos, targets):
# masking
NP = sinos.shape[2]
YJ_num = torch.randint(0, NP, (1,)).item()
YJ = sinos[:, :, YJ_num, :]

YJc = sinos.clone()
YJc[:, :, YJ_num, :] = 0

weight = NP / (NP - 1)
RJc = self.recon_fn(YJc * weight, self.model.operator)
output_recon_1 = self.model(RJc)

output_sino_1 = self.projector(output_recon_1)
AJ = output_sino_1[:, :, YJ_num, :]
batch_loss = ((AJ - YJ) ** 2).mean()

angle = random.uniform(0, 360)
rotated_output_recon_1 = TT.functional.rotate(
output_recon_1, angle, interpolation=TT.InterpolationMode.BILINEAR
)
rotated_output_recon_1 = torch.clamp(rotated_output_recon_1, min=0.0)
rotated_sinogram = self.projector(rotated_output_recon_1)
rotated_sinogram = torch.clamp(rotated_sinogram, min=0.0)
rotated_sinogram_noisy = ct_utils.sinogram_add_noise(
rotated_sinogram,
I0=self.solver_params.I0,
sigma=self.solver_params.sigma,
sigma_blur=self.solver_params.sigma_blur,
ks_value=3,
flat_field=None,
dark_field=None,
)

rotated_noisy_image = self.recon_fn(rotated_sinogram_noisy, self.model.operator)
output_recon_2 = self.model(rotated_noisy_image)
batch_loss += ((output_recon_2 - rotated_output_recon_1) ** 2).mean()
return batch_loss

# No validation in E2I
def validate(self):
return 0

def reconstruct(self, sinos):
input_recon = self.recon_fn(sinos, self.operator)
output_recon = self.model(input_recon)
return output_recon
3 changes: 0 additions & 3 deletions LION/optimizers/Noise2InverseSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ def _make_sub_operators(self) -> list[ts.Operator.Operator]:
ops = []
# maintain a copy of the original angles to restore later
angles = self.geometry.angles.copy()
assert (
len(angles) % self.sino_split_count == 0
), f"Cannot construct {self.sino_split_count} sinogram splits from {len(angles)} view angles. Ensure that sino_split_count divides #view angles"
for k in range(self.sino_split_count):
self.geometry.angles = angles[k :: self.sino_split_count]
sub_op = ct.make_operator(self.geometry)
Expand Down
Loading
Loading