Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 122 additions & 2 deletions httomolibgpu/recon/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
"CGLS3d_tomobar",
"FISTA3d_tomobar",
"ADMM3d_tomobar",
"OSEM3d_tomobar",
]

input_data_axis_labels = ["angles", "detY", "detX"] # set the labels of the input data
Expand Down Expand Up @@ -563,7 +564,7 @@ def FISTA3d_tomobar(
subsets_number: int
The number of the ordered subsets to accelerate convergence. Keep the value bellow 10 to avoid divergence.
data_fidelity: str
Data fidelity given as 'LS' (Least Squares), 'PWLS' (Penalised Weighted LS).
Data fidelity given as 'LS' (Least Squares)
regularisation_type: str
A method to use for regularisation. Currently PD_TV and ROF_TV are available.
regularisation_parameter: float
Expand Down Expand Up @@ -690,7 +691,7 @@ def ADMM3d_tomobar(
subsets_number: int
The number of the ordered subsets to accelerate convergence. The recommended range is between 12 to 24.
data_fidelity: str
Data fidelity given as 'LS' (Least Squares), 'PWLS' (Penalised Weightes LS).
Data fidelity given as 'LS' (Least Squares)
initialisation: str, optional
Initialise ADMM with the reconstructed image to reduce the number of iterations and accelerate. Choose between 'CGLS' or 'SIRT' when data
is noisy and/or undersampled. Choose 'FBP' when the data is of better quality (default) or None.
Expand Down Expand Up @@ -846,6 +847,125 @@ def ADMM3d_tomobar(
return cp.require(cp.swapaxes(reconstruction, 0, 1), requirements="C")


## %%%%%%%%%%%%%%%%%%%%%%% OSEM reconstruction %%%%%%%%%%%%%%%%%%%%%%%%%%%% ##
def OSEM3d_tomobar(
data: cp.ndarray,
angles: np.ndarray,
center: Optional[float] = None,
detector_pad: Union[bool, int] = False,
recon_size: Optional[int] = None,
recon_mask_radius: Optional[float] = 0.95,
iterations: int = 20,
subsets_number: int = 12,
regularisation_type: Literal["ROF_TV", "PD_TV"] = "PD_TV",
regularisation_parameter: float = 1.0,
regularisation_iterations: int = 30,
regularisation_half_precision: bool = True,
nonnegativity: bool = True,
gpu_id: int = 0,
) -> cp.ndarray:
"""
Ordered-Subsets Expectation-Maximisation method is the accelerated Maximum Likelihood Expectation-Maximisation (MLEM) algorithm.
Can be coupled with various types of regularisation or denoising operations :cite:`kazantsev2019ccpi` (currently accepts ROF_TV and PD_TV regularisations only).
Should be applied to reconstruct emission-type measurements, e.g., XRF tomography measurements.

Parameters
----------
data : cp.ndarray
Projection data as a CuPy array.
angles : np.ndarray
An array of angles given in radians.
center : float, optional
The center of rotation (CoR).
detector_pad : bool, int
Detector width padding with edge values to remove circle/arc type artifacts in the reconstruction. Set to True to perform
an automated padding or specify a certain value as an integer.
recon_size : int, optional
The squared size of the reconstructed slice. By default (None), the reconstructed size will be the dimension of the horizontal detector.
recon_mask_radius: float, optional
The radius of the circular mask that applies to the reconstructed slice in order to crop
out some undesirable artifacts. The values outside the given diameter will be set to zero.
To implement the cropping one can use the range [0.7-1.0] or set to None (2.0) when no cropping is needed.
iterations : int
The number of OSEM algorithm iterations. For OS method 20 interations is normally sufficient, while for MLEM one should run 300-500 iterations.
subsets_number: int
The number of the ordered subsets to accelerate convergence. One can set 'subsets_number' to 1 to achieve MLEM, but the number of 'iterations' should be also increased.
data_fidelity: str
Data fidelity given as 'LS' (Least Squares)
regularisation_type: str
A method to use for regularisation. Currently PD_TV and ROF_TV are available.
regularisation_parameter: float
The main regularisation parameter to control the amount of smoothing/noise removal. Larger values lead to stronger smoothing.
regularisation_iterations: int
The number of iterations for regularisers (aka INNER iterations).
regularisation_half_precision: bool
Perform faster regularisation computation in half-precision with a very minimal sacrifice in quality.
nonnegativity : bool
Impose nonnegativity constraint on the reconstructed image.
gpu_id : int
A GPU device index to perform operation on.

Returns
-------
cp.ndarray
The OSEM reconstructed volume as a CuPy array.
"""
### Data and parameters checks ###
methods_name = "OSEM3d_tomobar"
__common_data_parameters_check(
data,
angles,
methods_name,
center,
detector_pad,
recon_size,
recon_mask_radius,
gpu_id,
)
__common_iterative_basic_parameters_check(methods_name, iterations, nonnegativity)
__common_iterative_parameters_check(
methods_name,
subsets_number,
regularisation_type,
regularisation_parameter,
regularisation_iterations,
regularisation_half_precision,
)
###################################

RecToolsCP = _instantiate_iterative_recon_class(
data,
angles,
center,
detector_pad,
recon_size,
subsets_number,
gpu_id,
)

_data_ = {
"projection_data": data,
"data_axes_labels_order": input_data_axis_labels,
}

_algorithm_ = {
"iterations": iterations,
"nonnegativity": nonnegativity,
"recon_mask_radius": recon_mask_radius,
}

_regularisation_ = {
"method": regularisation_type, # Selected regularisation method
"regul_param": regularisation_parameter, # Regularisation parameter
"iterations": regularisation_iterations, # The number of regularisation iterations
"half_precision": regularisation_half_precision, # enabling half-precision calculation
}

reconstruction = RecToolsCP.OSEM(_data_, _algorithm_, _regularisation_)
cp._default_memory_pool.free_all_blocks()
return cp.require(cp.swapaxes(reconstruction, 0, 1), requirements="C")


## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% ##
def _instantiate_direct_recon_class(
data: cp.ndarray | Tuple[int, int, int],
Expand Down
16 changes: 16 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def distortion_correction_path(test_data_path):
return os.path.join(test_data_path, "distortion-correction")


@pytest.fixture(scope="session")
def data_XRFfile(test_data_path):
in_file = os.path.join(test_data_path, "Ga-Ka_aligned.npz")
return np.load(in_file)


# only load from disk once per session, and we use np.copy for the elements,
# to ensure data in this loaded file stays as originally loaded
@pytest.fixture(scope="session")
Expand Down Expand Up @@ -141,6 +147,16 @@ def detector_x(host_detector_x, ensure_clean_memory):
return cp.asarray(host_detector_x)


@pytest.fixture
def raw_data_Xrf(data_XRFfile):
return np.float32(np.copy(data_XRFfile["arr_0"]))


@pytest.fixture
def angles_data_Xrf(data_XRFfile):
return np.float32(np.copy(data_XRFfile["arr_1"]))


class MaxMemoryHook(cp.cuda.MemoryHook):
def __init__(self, initial=0):
self.max_mem = initial
Expand Down
26 changes: 26 additions & 0 deletions tests/test_recon/test_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
CGLS3d_tomobar,
FISTA3d_tomobar,
ADMM3d_tomobar,
OSEM3d_tomobar,
)
from numpy.testing import assert_allclose
import time
Expand Down Expand Up @@ -485,6 +486,31 @@ def test_reconstruct_ADMM3d_tomobar_rof_tv(data, flats, darks, ensure_clean_memo
assert recon_data.dtype == np.float32


def test_reconstruct_OSEM3d_tomobar_XRF_dataset(
raw_data_Xrf,
angles_data_Xrf,
):
_, detY, detX = np.shape(raw_data_Xrf)
data_cp = cp.asarray(np.float32(raw_data_Xrf), order="C")
angles_rad = np.deg2rad(angles_data_Xrf)

args = {
"angles": angles_rad,
"center": -0.5,
"detector_pad": False,
"recon_mask_radius": 2.0,
"iterations": 10,
"regularisation_parameter": 1.5,
}

recon_data = OSEM3d_tomobar(data=data_cp, **args)

assert recon_data.flags.c_contiguous
recon_data = cp.asnumpy(recon_data)
assert_allclose(np.min(recon_data), -5.1245663e-16, rtol=1e-04)
assert_allclose(np.max(recon_data), 71016.34, rtol=1e-04)


@pytest.mark.perf
def test_FBP3d_tomobar_performance(ensure_clean_memory):
dev = cp.cuda.Device()
Expand Down
1 change: 1 addition & 0 deletions zenodo-tests/test_recon/test_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
SIRT3d_tomobar,
FISTA3d_tomobar,
ADMM3d_tomobar,
OSEM3d_tomobar,
)
from httomolibgpu.misc.morph import sino_360_to_180
from numpy.testing import assert_allclose
Expand Down
Loading