Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 172 additions & 66 deletions pySNOM/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import spsolve
from scipy.ndimage import generic_filter
from scipy.ndimage import shift

MeasurementModes = Enum(
"MeasurementModes",
Expand Down Expand Up @@ -562,11 +563,15 @@ def transform(self, image):
class CalculateXCorrDrift(Transformation):
"""Calculates the drift between reference and template image"""

def __init__(self, image_ref):
def __init__(self, image_ref, upsample_factor=100, **kwargs):
self.image_ref = image_ref
self.upsample_factor = upsample_factor
self.kwargs = kwargs

def transform(self, image):
shift, _, _ = phase_cross_correlation(self.image_ref, image)
shift, _, _ = phase_cross_correlation(
self.image_ref, image, upsample_factor=self.upsample_factor, **self.kwargs
)
return shift


Expand All @@ -583,44 +588,184 @@ def transform(self, image):


class AlignImageStack(Transformation):
"""Calculates the drift between the given images and organize the comman areas into an aligned stack"""
"""Align a stack of images by estimating and correcting relative drift.

This transformation estimates pairwise shifts between a reference image
and subsequent images in a stack, then applies those shifts to produce
an aligned image stack cropped to the common overlapping region.

Parameters
----------
upsample_factor : int, optional
Upsampling factor used in the subpixel drift estimation, controlling
the precision of the cross-correlation-based shift calculation.
Higher values yield more precise but slower computations.

Attributes
----------
upsample_factor : int
The upsampling factor used when computing image drift.

Methods
-------
calculate(images)
Estimate the drift between a reference image and the remaining images
in the stack, returning a list of (x, y) shift vectors.
transform(images, shifts)
Apply the given shifts to the image stack, correct for drift, and
crop to the common cross-section.
"""

def __init__(self):
pass
def __init__(self, ref_index=0, upsample_factor=100):
self.ref_index = ref_index
self.upsample_factor = upsample_factor

def calculate(self, images):
"""
Estimate drift vectors for a stack of images.

The first image in `images` is used as the reference. Drift for each
subsequent image is computed via cross-correlation against this
reference, returning the corresponding (x, y) shifts.

Parameters
----------
images : sequence of array_like
Sequence of 2D images to be aligned. The first image serves as
the reference for drift estimation.

Returns
-------
list of tuple or None
List of (x, y) shift vectors for each image after the first,
where each tuple represents the drift relative to the reference
image. Returns ``None`` if fewer than two images are provided.
"""

shifts = []
crossrect = [0, 0, np.shape(images[0])[0], np.shape(images[0])[1]]
if len(images) > 1:
xcorr = CalculateXCorrDrift(images[0])
xcorr = CalculateXCorrDrift(
images[self.ref_index], upsample_factor=self.upsample_factor
)
for i in range(len(images)):
if i > 0:
if i != self.ref_index:
shifts.append(xcorr.transform(images[i]))
crossrect = shifted_cross_section(
rect1=crossrect,
rect2=[
-shifts[-1][0],
shifts[-1][1],
np.shape(images[i])[0],
np.shape(images[i])[1],
],
)
return shifts, crossrect
else:
shifts.append((0, 0))

return shifts
else:
return None

def transform(self, images, shifts, crossrect):
aligned_stack = []
for i in range(len(images)):
if i > 0:
shifter = CorrectImageDrift(shifts[i - 1])
aligned_stack.append(shifter.transform(images[i]))
aligned_stack[i] = cut_image(aligned_stack[i], crossrect)
else:
aligned_stack.append(cut_image(images[i], crossrect))
def transform(self, images, shifts):
"""
Apply drift correction and crop to common overlap.

Each image after the first is shifted according to the corresponding
drift vector in `shifts`, while the first image is kept unchanged.
After correction, all images are cropped to their common overlapping
cross-section.

Parameters
----------
images : sequence of array_like
Sequence of 2D images to be aligned, ordered consistently with
the `shifts` list.
shifts : sequence of tuple
Sequence of (x, y) drift vectors as returned by `calculate`,
where each element corresponds to the image at the same index
after the first.

Returns
-------
numpy.ndarray
Array containing the aligned and cropped image stack, with shape
(N, H_out, W_out), where N is the number of images and
``H_out``, ``W_out`` are the dimensions of the common
overlapping region.
"""

aligned_stack = np.zeros(
(len(images),) + images[0].shape, dtype=images[0].dtype
)
for k in range(len(images)):
aligned_stack[k] = shift_fill(images[k], shifts[k], fill=0)

aligned_stack = cut_cross_section(aligned_stack, shifts)

return aligned_stack


def cut_cross_section(image_stack, shifts):
"""
Crop a stack of shifted images to their common overlapping cross-section.

This function takes an image stack (e.g., a list or array of 2D images) and
a corresponding list of shift vectors applied to each image. It computes
the minimal rectangular region (cross-section) that is common to all shifted
images and returns the cropped stack.

Parameters
----------
image_stack : array_like
A sequence or NumPy array of shape (N, H, W), where N is the number
of images, and H and W are the image height and width, respectively.
shifts : array_like
A sequence of (x, y) shift vectors of shape (N, 2), specifying
the pixel displacements applied to each corresponding image.

Returns
-------
numpy.ndarray
A NumPy array of the cropped image stack containing only the
overlapping region across all shifted images. The shape of the
returned array is (N, H', W'), where H' and W' correspond to
the dimensions of the common cross-section.

Notes
-----
- This assumes that the shifts are given in pixel units for the x
(horizontal) and y (vertical) directions.
- The function ensures that all images are trimmed to the same region
to maintain alignment across the stack.
"""

# Calculate the cross-section of all shifted images
xmin, ymin = np.array(shifts)[:, 1].min(), np.array(shifts)[:, 0].min()
xmax, ymax = np.array(shifts)[:, 1].max(), np.array(shifts)[:, 0].max()
xmin, xmax = int(round(xmin)), int(round(xmax))
ymin, ymax = int(round(ymin)), int(round(ymax))

# Cut the images to the common cross-section
shape = np.array(image_stack).shape
print(shape)
slicex = slice(max(xmax, 0), min(shape[2], shape[2] + xmin))
slicey = slice(max(ymax, 0), min(shape[1], shape[1] + ymin))
print(slicey, slicex)

return np.array(image_stack)[:, slicey, slicex]


def shift_fill(img, sh, fill=np.nan):
"""Shift and fill invalid positions"""
aligned = shift(img, sh, mode="nearest")

(u, v) = img.shape

shifty = int(round(sh[0]))
aligned[: max(0, shifty), :] = fill
aligned[min(u, u + shifty) :, :] = fill

shiftx = int(round(sh[1]))
aligned[:, : max(0, shiftx)] = fill
aligned[:, min(v, v + shiftx) :] = fill

return aligned


################################################################################################


def sort_image_stack(images, wns):
"""Sort the image stack based on the wavenumber list"""

Expand Down Expand Up @@ -685,42 +830,3 @@ def flatten_stack(imagestack):
(imagestack.shape[0], imagestack.shape[1] * imagestack.shape[2])
)
return np.ravel(flattened_values, order="F")


def shifted_cross_section(rect1: list, rect2: list):
"""Calculates the cross-section of two rectangle shifted to each other"""
x1 = rect1[1]
x2 = rect2[1]
y1 = rect1[0]
y2 = rect2[0]
W1 = rect1[3]
W2 = rect2[3]
H1 = rect1[2]
H2 = rect2[2]

if y2 > y1:
Hn = H1 - (y2 - y1)
yn = y2
elif (y2 < y1) and (y1 + H1 > y2 + H2): # Negative shift and higher than H2
Hn = H2 + (y2 - y1)
yn = y1
else:
Hn = H1
yn = y1

if x2 > x1: # Positive shift
Wn = W1 - (x2 - x1)
xn = x2
elif (x2 < x1) and (x1 + W1 > x2 + W2): # Negative shift and higher than W2
Wn = W2 + (x2 - x1)
xn = x1
else:
Wn = W1
xn = x1

return int(yn), int(xn), int(Hn), int(Wn)


def cut_image(image, rect):
"""Cuts the part of the image array defined by rectangle"""
return image[-(rect[2]) : -(rect[0] + 1), rect[1] : rect[1] + rect[3]]
16 changes: 8 additions & 8 deletions pySNOM/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,17 +340,17 @@ def read(self):
params["Scan"] = "Fourier Scan"

return data, params



class ImageStackXYZReader(Reader):
"""Reads a list of images from the subfolders of the specified folder by loading the files that contain the pattern string int the filename"""

def __init__(self, fullfilepath=None):
super().__init__(fullfilepath)

def read(self):

if self.filename is None:
raise ValueError('No folder specified')
raise ValueError("No folder specified")
else:
with open(self.filename, encoding="utf8") as f:
x = next(f) # header
Expand All @@ -359,17 +359,17 @@ def read(self):

f.seek(0)
next(f)
datacols = np.arange(2, len(x)+2)
datacols = np.arange(2, len(x) + 2)
C_data = np.loadtxt(f, dtype="float", usecols=datacols)

f.seek(0)
next(f)

metacols = np.arange(0, 2)
meta = np.loadtxt(
f,
dtype={"names": ('Row','Column'), "formats": (float,float)},
usecols=metacols,
f,
dtype={"names": ("Row", "Column"), "formats": (float, float)},
usecols=metacols,
)

Max_row = len(np.unique(meta["Row"]))
Expand Down
7 changes: 5 additions & 2 deletions pySNOM/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,13 @@ def __init__(self, datatype=DataTypes.Phase, dounwrap=False):

def transform(self, spectrum, refspectrum):
if self.datatype == DataTypes.Phase or self.datatype == DataTypes.Topography:
newspectrum = np.angle(
np.exp(spectrum * complex(1j)) / np.exp(refspectrum * complex(1j))
)
if self.dounwrap:
return np.unwrap(spectrum - refspectrum)
return np.unwrap(newspectrum)
else:
return spectrum - refspectrum
return newspectrum
else:
return spectrum / refspectrum

Expand Down
11 changes: 5 additions & 6 deletions pySNOM/tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,12 @@ def test_stackalignment(self):
image1[10:40, 10:40] = 1
image2[20:50, 20:50] = 1

aligner = AlignImageStack()
shifts, crossrect = aligner.calculate([image1, image2])
np.testing.assert_equal(shifts, [np.asarray([-10.0, -10.0])])
np.testing.assert_equal(crossrect, [10, 0, 40, 90])
aligner = AlignImageStack(ref_index=0,upsample_factor=1)
shifts = aligner.calculate([image1, image2])
aligned_stack = aligner.transform(images=[image1, image2],shifts=shifts)

out = aligner.transform([image1, image2], shifts, crossrect)
np.testing.assert_equal(np.shape(out), (2, 29, 90))
np.testing.assert_equal(shifts[1], np.asarray([-10.0, -10.0]))
np.testing.assert_array_equal(np.round(aligned_stack[0]), np.round(aligned_stack[1]))


class TestHelperFunctions(unittest.TestCase):
Expand Down
Loading