diff --git a/pySNOM/images.py b/pySNOM/images.py index 2009036..5fd42ff 100644 --- a/pySNOM/images.py +++ b/pySNOM/images.py @@ -10,6 +10,7 @@ from scipy.sparse import coo_matrix from scipy.sparse.linalg import spsolve from scipy.ndimage import generic_filter +from scipy.ndimage import shift MeasurementModes = Enum( "MeasurementModes", @@ -562,11 +563,15 @@ def transform(self, image): class CalculateXCorrDrift(Transformation): """Calculates the drift between reference and template image""" - def __init__(self, image_ref): + def __init__(self, image_ref, upsample_factor=100, **kwargs): self.image_ref = image_ref + self.upsample_factor = upsample_factor + self.kwargs = kwargs def transform(self, image): - shift, _, _ = phase_cross_correlation(self.image_ref, image) + shift, _, _ = phase_cross_correlation( + self.image_ref, image, upsample_factor=self.upsample_factor, **self.kwargs + ) return shift @@ -583,44 +588,184 @@ def transform(self, image): class AlignImageStack(Transformation): - """Calculates the drift between the given images and organize the comman areas into an aligned stack""" + """Align a stack of images by estimating and correcting relative drift. + + This transformation estimates pairwise shifts between a reference image + and subsequent images in a stack, then applies those shifts to produce + an aligned image stack cropped to the common overlapping region. + + Parameters + ---------- + upsample_factor : int, optional + Upsampling factor used in the subpixel drift estimation, controlling + the precision of the cross-correlation-based shift calculation. + Higher values yield more precise but slower computations. + + Attributes + ---------- + upsample_factor : int + The upsampling factor used when computing image drift. + + Methods + ------- + calculate(images) + Estimate the drift between a reference image and the remaining images + in the stack, returning a list of (x, y) shift vectors. + transform(images, shifts) + Apply the given shifts to the image stack, correct for drift, and + crop to the common cross-section. + """ - def __init__(self): - pass + def __init__(self, ref_index=0, upsample_factor=100): + self.ref_index = ref_index + self.upsample_factor = upsample_factor def calculate(self, images): + """ + Estimate drift vectors for a stack of images. + + The first image in `images` is used as the reference. Drift for each + subsequent image is computed via cross-correlation against this + reference, returning the corresponding (x, y) shifts. + + Parameters + ---------- + images : sequence of array_like + Sequence of 2D images to be aligned. The first image serves as + the reference for drift estimation. + + Returns + ------- + list of tuple or None + List of (x, y) shift vectors for each image after the first, + where each tuple represents the drift relative to the reference + image. Returns ``None`` if fewer than two images are provided. + """ + shifts = [] - crossrect = [0, 0, np.shape(images[0])[0], np.shape(images[0])[1]] if len(images) > 1: - xcorr = CalculateXCorrDrift(images[0]) + xcorr = CalculateXCorrDrift( + images[self.ref_index], upsample_factor=self.upsample_factor + ) for i in range(len(images)): - if i > 0: + if i != self.ref_index: shifts.append(xcorr.transform(images[i])) - crossrect = shifted_cross_section( - rect1=crossrect, - rect2=[ - -shifts[-1][0], - shifts[-1][1], - np.shape(images[i])[0], - np.shape(images[i])[1], - ], - ) - return shifts, crossrect + else: + shifts.append((0, 0)) + + return shifts else: return None - def transform(self, images, shifts, crossrect): - aligned_stack = [] - for i in range(len(images)): - if i > 0: - shifter = CorrectImageDrift(shifts[i - 1]) - aligned_stack.append(shifter.transform(images[i])) - aligned_stack[i] = cut_image(aligned_stack[i], crossrect) - else: - aligned_stack.append(cut_image(images[i], crossrect)) + def transform(self, images, shifts): + """ + Apply drift correction and crop to common overlap. + + Each image after the first is shifted according to the corresponding + drift vector in `shifts`, while the first image is kept unchanged. + After correction, all images are cropped to their common overlapping + cross-section. + + Parameters + ---------- + images : sequence of array_like + Sequence of 2D images to be aligned, ordered consistently with + the `shifts` list. + shifts : sequence of tuple + Sequence of (x, y) drift vectors as returned by `calculate`, + where each element corresponds to the image at the same index + after the first. + + Returns + ------- + numpy.ndarray + Array containing the aligned and cropped image stack, with shape + (N, H_out, W_out), where N is the number of images and + ``H_out``, ``W_out`` are the dimensions of the common + overlapping region. + """ + + aligned_stack = np.zeros( + (len(images),) + images[0].shape, dtype=images[0].dtype + ) + for k in range(len(images)): + aligned_stack[k] = shift_fill(images[k], shifts[k], fill=0) + + aligned_stack = cut_cross_section(aligned_stack, shifts) + return aligned_stack +def cut_cross_section(image_stack, shifts): + """ + Crop a stack of shifted images to their common overlapping cross-section. + + This function takes an image stack (e.g., a list or array of 2D images) and + a corresponding list of shift vectors applied to each image. It computes + the minimal rectangular region (cross-section) that is common to all shifted + images and returns the cropped stack. + + Parameters + ---------- + image_stack : array_like + A sequence or NumPy array of shape (N, H, W), where N is the number + of images, and H and W are the image height and width, respectively. + shifts : array_like + A sequence of (x, y) shift vectors of shape (N, 2), specifying + the pixel displacements applied to each corresponding image. + + Returns + ------- + numpy.ndarray + A NumPy array of the cropped image stack containing only the + overlapping region across all shifted images. The shape of the + returned array is (N, H', W'), where H' and W' correspond to + the dimensions of the common cross-section. + + Notes + ----- + - This assumes that the shifts are given in pixel units for the x + (horizontal) and y (vertical) directions. + - The function ensures that all images are trimmed to the same region + to maintain alignment across the stack. + """ + + # Calculate the cross-section of all shifted images + xmin, ymin = np.array(shifts)[:, 1].min(), np.array(shifts)[:, 0].min() + xmax, ymax = np.array(shifts)[:, 1].max(), np.array(shifts)[:, 0].max() + xmin, xmax = int(round(xmin)), int(round(xmax)) + ymin, ymax = int(round(ymin)), int(round(ymax)) + + # Cut the images to the common cross-section + shape = np.array(image_stack).shape + print(shape) + slicex = slice(max(xmax, 0), min(shape[2], shape[2] + xmin)) + slicey = slice(max(ymax, 0), min(shape[1], shape[1] + ymin)) + print(slicey, slicex) + + return np.array(image_stack)[:, slicey, slicex] + + +def shift_fill(img, sh, fill=np.nan): + """Shift and fill invalid positions""" + aligned = shift(img, sh, mode="nearest") + + (u, v) = img.shape + + shifty = int(round(sh[0])) + aligned[: max(0, shifty), :] = fill + aligned[min(u, u + shifty) :, :] = fill + + shiftx = int(round(sh[1])) + aligned[:, : max(0, shiftx)] = fill + aligned[:, min(v, v + shiftx) :] = fill + + return aligned + + +################################################################################################ + + def sort_image_stack(images, wns): """Sort the image stack based on the wavenumber list""" @@ -685,42 +830,3 @@ def flatten_stack(imagestack): (imagestack.shape[0], imagestack.shape[1] * imagestack.shape[2]) ) return np.ravel(flattened_values, order="F") - - -def shifted_cross_section(rect1: list, rect2: list): - """Calculates the cross-section of two rectangle shifted to each other""" - x1 = rect1[1] - x2 = rect2[1] - y1 = rect1[0] - y2 = rect2[0] - W1 = rect1[3] - W2 = rect2[3] - H1 = rect1[2] - H2 = rect2[2] - - if y2 > y1: - Hn = H1 - (y2 - y1) - yn = y2 - elif (y2 < y1) and (y1 + H1 > y2 + H2): # Negative shift and higher than H2 - Hn = H2 + (y2 - y1) - yn = y1 - else: - Hn = H1 - yn = y1 - - if x2 > x1: # Positive shift - Wn = W1 - (x2 - x1) - xn = x2 - elif (x2 < x1) and (x1 + W1 > x2 + W2): # Negative shift and higher than W2 - Wn = W2 + (x2 - x1) - xn = x1 - else: - Wn = W1 - xn = x1 - - return int(yn), int(xn), int(Hn), int(Wn) - - -def cut_image(image, rect): - """Cuts the part of the image array defined by rectangle""" - return image[-(rect[2]) : -(rect[0] + 1), rect[1] : rect[1] + rect[3]] diff --git a/pySNOM/readers.py b/pySNOM/readers.py index 382b575..31abff8 100644 --- a/pySNOM/readers.py +++ b/pySNOM/readers.py @@ -340,7 +340,8 @@ def read(self): params["Scan"] = "Fourier Scan" return data, params - + + class ImageStackXYZReader(Reader): """Reads a list of images from the subfolders of the specified folder by loading the files that contain the pattern string int the filename""" @@ -348,9 +349,8 @@ def __init__(self, fullfilepath=None): super().__init__(fullfilepath) def read(self): - if self.filename is None: - raise ValueError('No folder specified') + raise ValueError("No folder specified") else: with open(self.filename, encoding="utf8") as f: x = next(f) # header @@ -359,17 +359,17 @@ def read(self): f.seek(0) next(f) - datacols = np.arange(2, len(x)+2) + datacols = np.arange(2, len(x) + 2) C_data = np.loadtxt(f, dtype="float", usecols=datacols) - + f.seek(0) next(f) metacols = np.arange(0, 2) meta = np.loadtxt( - f, - dtype={"names": ('Row','Column'), "formats": (float,float)}, - usecols=metacols, + f, + dtype={"names": ("Row", "Column"), "formats": (float, float)}, + usecols=metacols, ) Max_row = len(np.unique(meta["Row"])) diff --git a/pySNOM/spectra.py b/pySNOM/spectra.py index a921b54..1613df2 100644 --- a/pySNOM/spectra.py +++ b/pySNOM/spectra.py @@ -217,10 +217,13 @@ def __init__(self, datatype=DataTypes.Phase, dounwrap=False): def transform(self, spectrum, refspectrum): if self.datatype == DataTypes.Phase or self.datatype == DataTypes.Topography: + newspectrum = np.angle( + np.exp(spectrum * complex(1j)) / np.exp(refspectrum * complex(1j)) + ) if self.dounwrap: - return np.unwrap(spectrum - refspectrum) + return np.unwrap(newspectrum) else: - return spectrum - refspectrum + return newspectrum else: return spectrum / refspectrum diff --git a/pySNOM/tests/test_transform.py b/pySNOM/tests/test_transform.py index f532479..d8b9613 100644 --- a/pySNOM/tests/test_transform.py +++ b/pySNOM/tests/test_transform.py @@ -279,13 +279,12 @@ def test_stackalignment(self): image1[10:40, 10:40] = 1 image2[20:50, 20:50] = 1 - aligner = AlignImageStack() - shifts, crossrect = aligner.calculate([image1, image2]) - np.testing.assert_equal(shifts, [np.asarray([-10.0, -10.0])]) - np.testing.assert_equal(crossrect, [10, 0, 40, 90]) + aligner = AlignImageStack(ref_index=0,upsample_factor=1) + shifts = aligner.calculate([image1, image2]) + aligned_stack = aligner.transform(images=[image1, image2],shifts=shifts) - out = aligner.transform([image1, image2], shifts, crossrect) - np.testing.assert_equal(np.shape(out), (2, 29, 90)) + np.testing.assert_equal(shifts[1], np.asarray([-10.0, -10.0])) + np.testing.assert_array_equal(np.round(aligned_stack[0]), np.round(aligned_stack[1])) class TestHelperFunctions(unittest.TestCase):