Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions pySNOM/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from skimage.transform import warp
from skimage.registration import optical_flow_tvl1, phase_cross_correlation
from scipy.ndimage import fourier_shift
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import spsolve
from scipy.ndimage import generic_filter

MeasurementModes = Enum(
"MeasurementModes",
Expand Down Expand Up @@ -372,6 +375,122 @@ def calculate(self, data, mask=None):
def transform(self, data, mask=None):
return MaskedTransformation.transform(self, data, mask=mask)

class LaplaceFillIn(Transformation):
"""
Fill in missing (masked) regions of data using inward
interpolation via Laplace's equation. Handles edge and corner cases.
Original NATLAB code: https://github.com/EvanCzako/image-spike-removal/blob/master/remove_spikes.m
"""

def __init__(self, mask):
self.mask = mask

def transform(self,data):
"""
Parameters:
- data (2D np.ndarray): Input data.
- mask (2D np.ndarray): Boolean mask where True indicates missing values to fill.

Returns:
- filled (2D np.ndarray): Image with missing values filled.
"""

M, N = data.shape
num_pixels = M * N

# Flattened indices
u = np.flatnonzero(self.mask) # masked (unknown) pixels
w = np.flatnonzero(~self.mask) # known pixels

# Neighbor index offsets
u_north = u - 1
u_north = np.where(u % M != 0, u_north, 0) # Wrap prevention for top row
u_east = u + M
u_east = np.where(u_east < num_pixels, u_east, 0)
u_south = u + 1
u_south = np.where((u + 1) % M != 0, u_south, 0)
u_west = u - M
u_west = np.where(u_west >= 0, u_west, 0)

a = np.stack([u_north, u_east, u_south, u_west], axis=1)
b = (a > 0).astype(float)
sum_b = b.sum(axis=1, keepdims=True)
c = -b / np.maximum(sum_b, 1e-12)

# Sparse matrix entries
row_inds = np.concatenate([u, u, u, u, u])
col_inds = np.concatenate([u, u_north, u_east, u_south, u_west])
data_vals = np.concatenate([
np.ones(len(u)),
c[:, 0], c[:, 1], c[:, 2], c[:, 3]
])

# Remove invalid entries
valid = (col_inds >= 0) & (col_inds < num_pixels)
row_inds = row_inds[valid]
col_inds = col_inds[valid]
data_vals = data_vals[valid]

# Include identity rows for known pixels
row_inds = np.concatenate([row_inds, w])
col_inds = np.concatenate([col_inds, w])
data_vals = np.concatenate([data_vals, np.ones(len(w))])

# Build sparse matrix
A = coo_matrix((data_vals, (row_inds, col_inds)), shape=(num_pixels, num_pixels)).tocsr()

# Build RHS vector
b_vec = data.flatten()
b_vec[self.mask.flatten()] = 0

# Solve linear system
x = spsolve(A, b_vec)
filled = x.reshape(data.shape)

return filled

class ValueFillIn(Transformation):
def __init__(self, mask, value):
self.value = value
self.mask = mask

def transform(self, data):
data[np.isnan(self.mask)] = self.value

return data

class RemoveSpikes(Transformation):
def __init__(self, threshold=0.8, absolute_threshold=False, method='laplace', higher=False, value = 1.0):
self.threshold = threshold
self.method = method
self.higher = higher
self.value = value
self.absolute_threshold = absolute_threshold

def transform(self, data):
if not self.absolute_threshold:
norm_data = data/np.nanmedian(data)
else:
norm_data = data

if self.higher:
spike_mask = mask_from_datacondition(norm_data > self.threshold)
else:
spike_mask = mask_from_datacondition(norm_data < self.threshold)

# Use previously defined fill_region to fill spikes
match self.method:
case "laplace":
data = LaplaceFillIn(np.isnan(spike_mask)).transform(data)
case "median":
data = ValueFillIn(spike_mask,np.nanmedian(data)).transform(data)
case "manual":
data = ValueFillIn(spike_mask,self.value).transform(data)
case _:
data = spike_mask*data

return data


class ScarRemoval(Transformation):
def __init__(self, threshold=0.5, flip=False, datatype=DataTypes.Phase):
Expand Down
59 changes: 59 additions & 0 deletions pySNOM/tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
SimpleNormalize,
DataTypes,
AlignImageStack,
RemoveSpikes,
ValueFillIn,
ScarRemoval,
mask_from_datacondition,
dict_from_imagestack,
Expand Down Expand Up @@ -212,6 +214,63 @@ def test_min(self):
out = l.transform(d, mask=mask)
np.testing.assert_almost_equal(out, [-1.0, 0.0, 1.0])

class TestFillIn(unittest.TestCase):
def test_value_fillin(self):
d = np.ones([9, 9])
mask = np.random.rand(9,9)
mask[4,4] = np.nan

l = ValueFillIn(mask=mask,value=np.inf)

out = l.transform(d)
np.testing.assert_equal(out[4,4],np.inf)

class TestRemoveSpikes(unittest.TestCase):
def test_remove_laplace(self):
d = np.ones([9, 9])
d[4, 4] = 0.8

l = RemoveSpikes(threshold=0.9,method='laplace')

out = l.transform(d)
np.testing.assert_almost_equal(out[4,4],1.0)

def test_remove_higher_laplace(self):
d = np.ones([9, 9])
d[4, 4] = 1.2

l = RemoveSpikes(threshold=1.1,method='laplace',higher=True)

out = l.transform(d)
np.testing.assert_almost_equal(out[4,4],1.0)

def test_remove_manual(self):
d = np.ones([9, 9])
d[4, 4] = 0.8

l = RemoveSpikes(threshold=0.9,method='manual',value=1.11111)

out = l.transform(d)
np.testing.assert_almost_equal(out[4,4],1.11111)

def test_remove_median(self):
d = np.ones([9, 9])
d[4, 4] = 0.8

l = RemoveSpikes(threshold=0.9,method='median')

out = l.transform(d)
np.testing.assert_almost_equal(out[4,4],1.0)

def test_remove_nan(self):
d = np.ones([9, 9])
d[4, 4] = 0.8

l = RemoveSpikes(threshold=0.9,method='asdfsdfhdfg')

out = l.transform(d)
np.testing.assert_almost_equal(out[4,4],np.nan)


class TestAlignImageStack(unittest.TestCase):
def test_stackalignment(self):
Expand Down