From ecfd7ccdf157c4b48526587af4ebb142c95032ad Mon Sep 17 00:00:00 2001 From: Gergely Nemeth Date: Tue, 5 May 2026 10:50:49 +0200 Subject: [PATCH 1/5] Complex normalization for phase Co-authored-by: Copilot --- pySNOM/spectra.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pySNOM/spectra.py b/pySNOM/spectra.py index a921b54..758d6f9 100644 --- a/pySNOM/spectra.py +++ b/pySNOM/spectra.py @@ -217,10 +217,11 @@ def __init__(self, datatype=DataTypes.Phase, dounwrap=False): def transform(self, spectrum, refspectrum): if self.datatype == DataTypes.Phase or self.datatype == DataTypes.Topography: + newspectrum = np.angle(np.exp(spectrum * complex(1j)) / np.exp(refspectrum * complex(1j))) if self.dounwrap: - return np.unwrap(spectrum - refspectrum) + return np.unwrap(newspectrum) else: - return spectrum - refspectrum + return newspectrum else: return spectrum / refspectrum From 5c36ec2c32fa73cd983d1947777fcad853b1545e Mon Sep 17 00:00:00 2001 From: Gergely Nemeth Date: Fri, 23 Jan 2026 11:16:52 +0100 Subject: [PATCH 2/5] Fix and docs for image registration --- pySNOM/images.py | 243 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 215 insertions(+), 28 deletions(-) diff --git a/pySNOM/images.py b/pySNOM/images.py index 2009036..0a34eb9 100644 --- a/pySNOM/images.py +++ b/pySNOM/images.py @@ -10,6 +10,7 @@ from scipy.sparse import coo_matrix from scipy.sparse.linalg import spsolve from scipy.ndimage import generic_filter +from scipy.ndimage import shift MeasurementModes = Enum( "MeasurementModes", @@ -562,11 +563,13 @@ def transform(self, image): class CalculateXCorrDrift(Transformation): """Calculates the drift between reference and template image""" - def __init__(self, image_ref): + def __init__(self, image_ref, upsample_factor=100, **kwargs): self.image_ref = image_ref + self.upsample_factor = upsample_factor + self.kwargs = kwargs def transform(self, image): - shift, _, _ = phase_cross_correlation(self.image_ref, image) + shift, _, _ = phase_cross_correlation(self.image_ref, image, upsample_factor=self.upsample_factor, **self.kwargs) return shift @@ -581,45 +584,229 @@ def transform(self, image): offset_phase = np.fft.ifftn(offset_phase) return offset_phase.real - class AlignImageStack(Transformation): - """Calculates the drift between the given images and organize the comman areas into an aligned stack""" + """Align a stack of images by estimating and correcting relative drift. + + This transformation estimates pairwise shifts between a reference image + and subsequent images in a stack, then applies those shifts to produce + an aligned image stack cropped to the common overlapping region. + + Parameters + ---------- + upsample_factor : int, optional + Upsampling factor used in the subpixel drift estimation, controlling + the precision of the cross-correlation-based shift calculation. + Higher values yield more precise but slower computations. + + Attributes + ---------- + upsample_factor : int + The upsampling factor used when computing image drift. + + Methods + ------- + calculate(images) + Estimate the drift between a reference image and the remaining images + in the stack, returning a list of (x, y) shift vectors. + transform(images, shifts) + Apply the given shifts to the image stack, correct for drift, and + crop to the common cross-section. + """ - def __init__(self): - pass + def __init__(self, ref_index=0, upsample_factor=100): + self.ref_index = ref_index + self.upsample_factor = upsample_factor def calculate(self, images): + """ + Estimate drift vectors for a stack of images. + + The first image in `images` is used as the reference. Drift for each + subsequent image is computed via cross-correlation against this + reference, returning the corresponding (x, y) shifts. + + Parameters + ---------- + images : sequence of array_like + Sequence of 2D images to be aligned. The first image serves as + the reference for drift estimation. + + Returns + ------- + list of tuple or None + List of (x, y) shift vectors for each image after the first, + where each tuple represents the drift relative to the reference + image. Returns ``None`` if fewer than two images are provided. + """ + shifts = [] - crossrect = [0, 0, np.shape(images[0])[0], np.shape(images[0])[1]] if len(images) > 1: - xcorr = CalculateXCorrDrift(images[0]) + xcorr = CalculateXCorrDrift(images[self.ref_index], upsample_factor=self.upsample_factor) for i in range(len(images)): - if i > 0: + if i != self.ref_index: shifts.append(xcorr.transform(images[i])) - crossrect = shifted_cross_section( - rect1=crossrect, - rect2=[ - -shifts[-1][0], - shifts[-1][1], - np.shape(images[i])[0], - np.shape(images[i])[1], - ], - ) - return shifts, crossrect + else: + shifts.append((0, 0)) + + return shifts else: return None - def transform(self, images, shifts, crossrect): - aligned_stack = [] - for i in range(len(images)): - if i > 0: - shifter = CorrectImageDrift(shifts[i - 1]) - aligned_stack.append(shifter.transform(images[i])) - aligned_stack[i] = cut_image(aligned_stack[i], crossrect) - else: - aligned_stack.append(cut_image(images[i], crossrect)) + def transform(self, images, shifts): + """ + Apply drift correction and crop to common overlap. + + Each image after the first is shifted according to the corresponding + drift vector in `shifts`, while the first image is kept unchanged. + After correction, all images are cropped to their common overlapping + cross-section. + + Parameters + ---------- + images : sequence of array_like + Sequence of 2D images to be aligned, ordered consistently with + the `shifts` list. + shifts : sequence of tuple + Sequence of (x, y) drift vectors as returned by `calculate`, + where each element corresponds to the image at the same index + after the first. + + Returns + ------- + numpy.ndarray + Array containing the aligned and cropped image stack, with shape + (N, H_out, W_out), where N is the number of images and + ``H_out``, ``W_out`` are the dimensions of the common + overlapping region. + """ + + # aligned_stack = [] + aligned_stack = np.zeros((len(images),) + images[0].shape, dtype=images[0].dtype) + for k in range(len(images)): + aligned_stack[k] = shift_fill(images[k], shifts[k],fill=0) + + # for i in range(len(images)): + # shifter = CorrectImageDrift(shifts[i]) + # aligned_stack.append(shifter.transform(images[i])) + + aligned_stack = cut_cross_section(aligned_stack, shifts) + return aligned_stack +def cut_cross_section(image_stack, shifts): + """ + Crop a stack of shifted images to their common overlapping cross-section. + + This function takes an image stack (e.g., a list or array of 2D images) and + a corresponding list of shift vectors applied to each image. It computes + the minimal rectangular region (cross-section) that is common to all shifted + images and returns the cropped stack. + + Parameters + ---------- + image_stack : array_like + A sequence or NumPy array of shape (N, H, W), where N is the number + of images, and H and W are the image height and width, respectively. + shifts : array_like + A sequence of (x, y) shift vectors of shape (N, 2), specifying + the pixel displacements applied to each corresponding image. + + Returns + ------- + numpy.ndarray + A NumPy array of the cropped image stack containing only the + overlapping region across all shifted images. The shape of the + returned array is (N, H', W'), where H' and W' correspond to + the dimensions of the common cross-section. + + Notes + ----- + - This assumes that the shifts are given in pixel units for the x + (horizontal) and y (vertical) directions. + - The function ensures that all images are trimmed to the same region + to maintain alignment across the stack. + """ + + # Calculate the cross-section of all shifted images + xmin, ymin = np.array(shifts)[:, 1].min(), np.array(shifts)[:, 0].min() + xmax, ymax = np.array(shifts)[:, 1].max(), np.array(shifts)[:, 0].max() + xmin, xmax = int(round(xmin)), int(round(xmax)) + ymin, ymax = int(round(ymin)), int(round(ymax)) + + # Cut the images to the common cross-section + shape = np.array(image_stack).shape + print(shape) + slicex = slice(max(xmax, 0), min(shape[2], shape[2]+xmin)) + slicey = slice(max(ymax, 0), min(shape[1], shape[1]+ymin)) + print(slicey,slicex) + + return np.array(image_stack)[:,slicey, slicex] + + +############################################################################################### +class RegisterTranslation: + def __init__(self, upsample_factor=1): + self.upsample_factor = upsample_factor + + def __call__(self, base, shifted): + """Return the shift (in each axis) needed to align to the base. + Shift down and right are positive. First coordinate belongs to + the first axis (rows in numpy).""" + s, _, _ = phase_cross_correlation( + base, shifted, upsample_factor=self.upsample_factor + ) + return s + + +def shift_fill(img, sh, fill=np.nan): + """Shift and fill invalid positions""" + aligned = shift(img, sh, mode='nearest') + + (u, v) = img.shape + + shifty = int(round(sh[0])) + aligned[:max(0, shifty), :] = fill + aligned[min(u, u+shifty):, :] = fill + + shiftx = int(round(sh[1])) + aligned[:, :max(0, shiftx)] = fill + aligned[:, min(v, v+shiftx):] = fill + + return aligned + + +def alignstack(raw, shiftfn, ref_frame_num=0): + """Align to the first image""" + shifts = calculate_stack_shifts(raw, shiftfn, ref_frame_num=ref_frame_num) + aligned = alignstack_with_shifts(raw, shifts) + + return shifts, aligned + + +def calculate_stack_shifts(raw, shiftfn, ref_frame_num=0): + """Calculate the shifts for each image in the stack""" + base = raw[ref_frame_num] + shifts = [] + + for i, image in enumerate(raw): + if i != ref_frame_num: + shifts.append(shiftfn(base, image)) + else: + shifts.append((0, 0)) + shifts = np.array(shifts) + + return shifts + + +def alignstack_with_shifts(raw, shifts): + """Aligns the stack using the provided shifts""" + aligned = np.zeros((len(raw),) + raw[0].shape, dtype=raw[0].dtype) + for k in range(len(raw)): + aligned[k] = shift_fill(raw[k], shifts[k]) + + return aligned + +################################################################################################ def sort_image_stack(images, wns): """Sort the image stack based on the wavenumber list""" From 9381d782dbdd4570bd2dc489862b2b30965e17c7 Mon Sep 17 00:00:00 2001 From: Gergely Nemeth Date: Fri, 23 Jan 2026 11:57:14 +0100 Subject: [PATCH 3/5] Update tests --- pySNOM/tests/test_transform.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pySNOM/tests/test_transform.py b/pySNOM/tests/test_transform.py index f532479..d8b9613 100644 --- a/pySNOM/tests/test_transform.py +++ b/pySNOM/tests/test_transform.py @@ -279,13 +279,12 @@ def test_stackalignment(self): image1[10:40, 10:40] = 1 image2[20:50, 20:50] = 1 - aligner = AlignImageStack() - shifts, crossrect = aligner.calculate([image1, image2]) - np.testing.assert_equal(shifts, [np.asarray([-10.0, -10.0])]) - np.testing.assert_equal(crossrect, [10, 0, 40, 90]) + aligner = AlignImageStack(ref_index=0,upsample_factor=1) + shifts = aligner.calculate([image1, image2]) + aligned_stack = aligner.transform(images=[image1, image2],shifts=shifts) - out = aligner.transform([image1, image2], shifts, crossrect) - np.testing.assert_equal(np.shape(out), (2, 29, 90)) + np.testing.assert_equal(shifts[1], np.asarray([-10.0, -10.0])) + np.testing.assert_array_equal(np.round(aligned_stack[0]), np.round(aligned_stack[1])) class TestHelperFunctions(unittest.TestCase): From 497e5778ed28d0f08de55528a286a46f0e501edf Mon Sep 17 00:00:00 2001 From: Gergely Nemeth Date: Sun, 17 May 2026 11:48:12 +0200 Subject: [PATCH 4/5] Lint --- pySNOM/images.py | 60 +++++++++++++++++++++++++++-------------------- pySNOM/readers.py | 16 ++++++------- pySNOM/spectra.py | 4 +++- 3 files changed, 46 insertions(+), 34 deletions(-) diff --git a/pySNOM/images.py b/pySNOM/images.py index 0a34eb9..84521d6 100644 --- a/pySNOM/images.py +++ b/pySNOM/images.py @@ -569,7 +569,9 @@ def __init__(self, image_ref, upsample_factor=100, **kwargs): self.kwargs = kwargs def transform(self, image): - shift, _, _ = phase_cross_correlation(self.image_ref, image, upsample_factor=self.upsample_factor, **self.kwargs) + shift, _, _ = phase_cross_correlation( + self.image_ref, image, upsample_factor=self.upsample_factor, **self.kwargs + ) return shift @@ -584,6 +586,7 @@ def transform(self, image): offset_phase = np.fft.ifftn(offset_phase) return offset_phase.real + class AlignImageStack(Transformation): """Align a stack of images by estimating and correcting relative drift. @@ -638,16 +641,18 @@ def calculate(self, images): where each tuple represents the drift relative to the reference image. Returns ``None`` if fewer than two images are provided. """ - + shifts = [] if len(images) > 1: - xcorr = CalculateXCorrDrift(images[self.ref_index], upsample_factor=self.upsample_factor) + xcorr = CalculateXCorrDrift( + images[self.ref_index], upsample_factor=self.upsample_factor + ) for i in range(len(images)): if i != self.ref_index: shifts.append(xcorr.transform(images[i])) else: shifts.append((0, 0)) - + return shifts else: return None @@ -681,9 +686,11 @@ def transform(self, images, shifts): """ # aligned_stack = [] - aligned_stack = np.zeros((len(images),) + images[0].shape, dtype=images[0].dtype) + aligned_stack = np.zeros( + (len(images),) + images[0].shape, dtype=images[0].dtype + ) for k in range(len(images)): - aligned_stack[k] = shift_fill(images[k], shifts[k],fill=0) + aligned_stack[k] = shift_fill(images[k], shifts[k], fill=0) # for i in range(len(images)): # shifter = CorrectImageDrift(shifts[i]) @@ -693,37 +700,38 @@ def transform(self, images, shifts): return aligned_stack + def cut_cross_section(image_stack, shifts): """ Crop a stack of shifted images to their common overlapping cross-section. - This function takes an image stack (e.g., a list or array of 2D images) and - a corresponding list of shift vectors applied to each image. It computes - the minimal rectangular region (cross-section) that is common to all shifted + This function takes an image stack (e.g., a list or array of 2D images) and + a corresponding list of shift vectors applied to each image. It computes + the minimal rectangular region (cross-section) that is common to all shifted images and returns the cropped stack. Parameters ---------- image_stack : array_like - A sequence or NumPy array of shape (N, H, W), where N is the number + A sequence or NumPy array of shape (N, H, W), where N is the number of images, and H and W are the image height and width, respectively. shifts : array_like - A sequence of (x, y) shift vectors of shape (N, 2), specifying + A sequence of (x, y) shift vectors of shape (N, 2), specifying the pixel displacements applied to each corresponding image. Returns ------- numpy.ndarray - A NumPy array of the cropped image stack containing only the - overlapping region across all shifted images. The shape of the - returned array is (N, H', W'), where H' and W' correspond to + A NumPy array of the cropped image stack containing only the + overlapping region across all shifted images. The shape of the + returned array is (N, H', W'), where H' and W' correspond to the dimensions of the common cross-section. Notes ----- - - This assumes that the shifts are given in pixel units for the x + - This assumes that the shifts are given in pixel units for the x (horizontal) and y (vertical) directions. - - The function ensures that all images are trimmed to the same region + - The function ensures that all images are trimmed to the same region to maintain alignment across the stack. """ @@ -736,11 +744,11 @@ def cut_cross_section(image_stack, shifts): # Cut the images to the common cross-section shape = np.array(image_stack).shape print(shape) - slicex = slice(max(xmax, 0), min(shape[2], shape[2]+xmin)) - slicey = slice(max(ymax, 0), min(shape[1], shape[1]+ymin)) - print(slicey,slicex) + slicex = slice(max(xmax, 0), min(shape[2], shape[2] + xmin)) + slicey = slice(max(ymax, 0), min(shape[1], shape[1] + ymin)) + print(slicey, slicex) - return np.array(image_stack)[:,slicey, slicex] + return np.array(image_stack)[:, slicey, slicex] ############################################################################################### @@ -760,17 +768,17 @@ def __call__(self, base, shifted): def shift_fill(img, sh, fill=np.nan): """Shift and fill invalid positions""" - aligned = shift(img, sh, mode='nearest') + aligned = shift(img, sh, mode="nearest") (u, v) = img.shape shifty = int(round(sh[0])) - aligned[:max(0, shifty), :] = fill - aligned[min(u, u+shifty):, :] = fill + aligned[: max(0, shifty), :] = fill + aligned[min(u, u + shifty) :, :] = fill shiftx = int(round(sh[1])) - aligned[:, :max(0, shiftx)] = fill - aligned[:, min(v, v+shiftx):] = fill + aligned[:, : max(0, shiftx)] = fill + aligned[:, min(v, v + shiftx) :] = fill return aligned @@ -806,8 +814,10 @@ def alignstack_with_shifts(raw, shifts): return aligned + ################################################################################################ + def sort_image_stack(images, wns): """Sort the image stack based on the wavenumber list""" diff --git a/pySNOM/readers.py b/pySNOM/readers.py index 382b575..31abff8 100644 --- a/pySNOM/readers.py +++ b/pySNOM/readers.py @@ -340,7 +340,8 @@ def read(self): params["Scan"] = "Fourier Scan" return data, params - + + class ImageStackXYZReader(Reader): """Reads a list of images from the subfolders of the specified folder by loading the files that contain the pattern string int the filename""" @@ -348,9 +349,8 @@ def __init__(self, fullfilepath=None): super().__init__(fullfilepath) def read(self): - if self.filename is None: - raise ValueError('No folder specified') + raise ValueError("No folder specified") else: with open(self.filename, encoding="utf8") as f: x = next(f) # header @@ -359,17 +359,17 @@ def read(self): f.seek(0) next(f) - datacols = np.arange(2, len(x)+2) + datacols = np.arange(2, len(x) + 2) C_data = np.loadtxt(f, dtype="float", usecols=datacols) - + f.seek(0) next(f) metacols = np.arange(0, 2) meta = np.loadtxt( - f, - dtype={"names": ('Row','Column'), "formats": (float,float)}, - usecols=metacols, + f, + dtype={"names": ("Row", "Column"), "formats": (float, float)}, + usecols=metacols, ) Max_row = len(np.unique(meta["Row"])) diff --git a/pySNOM/spectra.py b/pySNOM/spectra.py index 758d6f9..1613df2 100644 --- a/pySNOM/spectra.py +++ b/pySNOM/spectra.py @@ -217,7 +217,9 @@ def __init__(self, datatype=DataTypes.Phase, dounwrap=False): def transform(self, spectrum, refspectrum): if self.datatype == DataTypes.Phase or self.datatype == DataTypes.Topography: - newspectrum = np.angle(np.exp(spectrum * complex(1j)) / np.exp(refspectrum * complex(1j))) + newspectrum = np.angle( + np.exp(spectrum * complex(1j)) / np.exp(refspectrum * complex(1j)) + ) if self.dounwrap: return np.unwrap(newspectrum) else: From 34d98be9b3bc9bc8b415583ab034b6f673feb29f Mon Sep 17 00:00:00 2001 From: Gergely Nemeth Date: Sun, 17 May 2026 12:19:55 +0200 Subject: [PATCH 5/5] Removed unused code --- pySNOM/images.py | 91 ------------------------------------------------ 1 file changed, 91 deletions(-) diff --git a/pySNOM/images.py b/pySNOM/images.py index 84521d6..5fd42ff 100644 --- a/pySNOM/images.py +++ b/pySNOM/images.py @@ -685,17 +685,12 @@ def transform(self, images, shifts): overlapping region. """ - # aligned_stack = [] aligned_stack = np.zeros( (len(images),) + images[0].shape, dtype=images[0].dtype ) for k in range(len(images)): aligned_stack[k] = shift_fill(images[k], shifts[k], fill=0) - # for i in range(len(images)): - # shifter = CorrectImageDrift(shifts[i]) - # aligned_stack.append(shifter.transform(images[i])) - aligned_stack = cut_cross_section(aligned_stack, shifts) return aligned_stack @@ -751,21 +746,6 @@ def cut_cross_section(image_stack, shifts): return np.array(image_stack)[:, slicey, slicex] -############################################################################################### -class RegisterTranslation: - def __init__(self, upsample_factor=1): - self.upsample_factor = upsample_factor - - def __call__(self, base, shifted): - """Return the shift (in each axis) needed to align to the base. - Shift down and right are positive. First coordinate belongs to - the first axis (rows in numpy).""" - s, _, _ = phase_cross_correlation( - base, shifted, upsample_factor=self.upsample_factor - ) - return s - - def shift_fill(img, sh, fill=np.nan): """Shift and fill invalid positions""" aligned = shift(img, sh, mode="nearest") @@ -783,38 +763,6 @@ def shift_fill(img, sh, fill=np.nan): return aligned -def alignstack(raw, shiftfn, ref_frame_num=0): - """Align to the first image""" - shifts = calculate_stack_shifts(raw, shiftfn, ref_frame_num=ref_frame_num) - aligned = alignstack_with_shifts(raw, shifts) - - return shifts, aligned - - -def calculate_stack_shifts(raw, shiftfn, ref_frame_num=0): - """Calculate the shifts for each image in the stack""" - base = raw[ref_frame_num] - shifts = [] - - for i, image in enumerate(raw): - if i != ref_frame_num: - shifts.append(shiftfn(base, image)) - else: - shifts.append((0, 0)) - shifts = np.array(shifts) - - return shifts - - -def alignstack_with_shifts(raw, shifts): - """Aligns the stack using the provided shifts""" - aligned = np.zeros((len(raw),) + raw[0].shape, dtype=raw[0].dtype) - for k in range(len(raw)): - aligned[k] = shift_fill(raw[k], shifts[k]) - - return aligned - - ################################################################################################ @@ -882,42 +830,3 @@ def flatten_stack(imagestack): (imagestack.shape[0], imagestack.shape[1] * imagestack.shape[2]) ) return np.ravel(flattened_values, order="F") - - -def shifted_cross_section(rect1: list, rect2: list): - """Calculates the cross-section of two rectangle shifted to each other""" - x1 = rect1[1] - x2 = rect2[1] - y1 = rect1[0] - y2 = rect2[0] - W1 = rect1[3] - W2 = rect2[3] - H1 = rect1[2] - H2 = rect2[2] - - if y2 > y1: - Hn = H1 - (y2 - y1) - yn = y2 - elif (y2 < y1) and (y1 + H1 > y2 + H2): # Negative shift and higher than H2 - Hn = H2 + (y2 - y1) - yn = y1 - else: - Hn = H1 - yn = y1 - - if x2 > x1: # Positive shift - Wn = W1 - (x2 - x1) - xn = x2 - elif (x2 < x1) and (x1 + W1 > x2 + W2): # Negative shift and higher than W2 - Wn = W2 + (x2 - x1) - xn = x1 - else: - Wn = W1 - xn = x1 - - return int(yn), int(xn), int(Hn), int(Wn) - - -def cut_image(image, rect): - """Cuts the part of the image array defined by rectangle""" - return image[-(rect[2]) : -(rect[0] + 1), rect[1] : rect[1] + rect[3]]