-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
393 lines (329 loc) · 13.9 KB
/
app.py
File metadata and controls
393 lines (329 loc) · 13.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
import streamlit as st
from PIL import Image, ImageDraw
import numpy as np
import time
import os
# Setting this to None removes Pillow's cap on image pixel count (its
# decompression-bomb guard); the scans profiled in read_image() are far
# larger than the default limit.
Image.MAX_IMAGE_PIXELS = None
# Usage:
# conda activate aegle_patch_viewer
# streamlit run app.py --server.headless true
# When True, main() displays a synthetic grid image and mock patches instead
# of reading image_path from disk.
mock_flg = False
# Absolute, machine-specific path of the image shown by the viewer.
image_path = "/Users/kuangda/Developer/1-projects/4-codex-analysis/0-phenocycler-penntmc-pipeline/aegle_patch_viewer/data/extended_extracted_channel_image.png"
# Alternative inputs kept for quick switching:
# image_path = "/Users/kuangda/Developer/1-projects/4-codex-analysis/0-phenocycler-penntmc-pipeline/aegle_patch_viewer/NW_1_Scan1_rgb.png"
# image_path = "/Users/kuangda/Developer/1-projects/4-codex-analysis/0-phenocycler-penntmc-pipeline/aegle_patch_viewer/NW_1_Scan1_dev_rgb.png"
# Key parameters
# NOTE(review): main() binds its own local patch_height / patch_width /
# overlap from the sidebar widgets, so these module-level values look unused
# by the code in this file -- confirm before relying on or removing them.
patch_height = 1000
patch_width = 1000
# Overlap between neighbouring patches, as a fraction of the patch size.
overlap = 0.1
# Other patch sizes that have been tried:
# patch_height = 5000
# patch_width = 5000
# patch_height = 10000
# patch_width = 10000
# patch_height = 20000
# patch_width = 20000
# overlap = 0.0
def extend_image(image, patch_height, patch_width, step_height, step_width):
    """
    Zero-pad an image on its bottom and right edges before patch cropping.

    Only the first two axes (height, width) are padded; any further axes,
    such as a channel axis, are left untouched. This means 2-D grayscale
    arrays are accepted as well as 3-D colour arrays.

    Args:
        image (np.ndarray): The image to extend, with height and width as
            its first two axes.
        patch_height (int): Height of each patch.
        patch_width (int): Width of each patch.
        step_height (int): Step size in height (must be non-zero).
        step_width (int): Step size in width (must be non-zero).

    Returns:
        np.ndarray: Extended image. The original pixels stay in the top-left
        corner and the added rows/columns are filled with zeros.

    Raises:
        ValueError: If ``image`` has fewer than two dimensions.
    """
    if image.ndim < 2:
        raise ValueError(
            f"image must have at least 2 dimensions, got shape {image.shape}"
        )
    img_height, img_width = image.shape[:2]

    # NOTE(review): the remainder is reduced modulo the patch size rather than
    # the step size, so the padding can be larger than the minimum needed to
    # align the last patch. Kept unchanged so output dimensions stay the same.
    pad_height = (
        patch_height - (img_height - patch_height) % step_height
    ) % patch_height
    pad_width = (patch_width - (img_width - patch_width) % step_width) % patch_width

    # Pad height and width at the far edge only; trailing axes get no padding.
    pad_spec = ((0, pad_height), (0, pad_width)) + ((0, 0),) * (image.ndim - 2)
    extended_image = np.pad(
        image,
        pad_spec,
        mode="constant",
        constant_values=0,
    )
    print(f"Extended image shape: {extended_image.shape}")
    return extended_image
# Synthetic stand-in image used when no real scan is loaded
def create_mock_image(width=500, height=500, color=(255, 255, 255)):
    """
    Build an RGB test image: a solid background overlaid with a grey grid.

    Args:
        width (int): Image width in pixels.
        height (int): Image height in pixels.
        color (tuple): RGB background colour.

    Returns:
        PIL.Image.Image: The generated image, with grid lines every 50 px.
    """
    grid_color = (200, 200, 200)
    spacing = 50

    canvas = Image.new("RGB", (width, height), color=color)
    pen = ImageDraw.Draw(canvas)

    # Vertical grid lines first, then horizontal ones.
    for column in range(0, width, spacing):
        pen.line((column, 0, column, height), fill=grid_color)
    for row in range(0, height, spacing):
        pen.line((0, row, width, row), fill=grid_color)

    return canvas
# Patch grid for the mock image
def generate_mock_patches(image_size, patch_size=100, overlap=0):
    """
    Tile an image with square patches and return their bounding boxes.

    Args:
        image_size (tuple): ``(width, height)`` of the image in pixels.
        patch_size (int): Side length of each patch in pixels.
        overlap (int): Number of pixels by which neighbouring patches overlap;
            the stride between patch origins is ``patch_size - overlap``.

    Returns:
        dict: Maps a 1-based patch index (row-major order) to
        ``(x1, y1, x2, y2)``. Boxes at the right/bottom border are clipped
        to the image bounds.
    """
    width, height = image_size
    stride = patch_size - overlap

    # Row-major sequence of top-left corners.
    corners = (
        (left, top)
        for top in range(0, height, stride)
        for left in range(0, width, stride)
    )
    return {
        number: (
            left,
            top,
            min(left + patch_size, width),
            min(top + patch_size, height),
        )
        for number, (left, top) in enumerate(corners, start=1)
    }
@st.cache_data
def read_image(
    image_path="NW_1_Scan1_rgb.png",
    patch_height=1440,
    patch_width=1920,
    step_height=100,
    step_width=100,
):
    """
    Load an image from disk, pad it for patch cropping, and return a PIL image.

    The file is opened with Pillow, converted to a numpy array, padded with
    ``extend_image`` and converted back to a PIL image. Timing information is
    printed after every step; each printed time is cumulative, measured from
    the start of the function.

    Args:
        image_path (str): Path of the image file to read.
        patch_height (int): Height of each patch.
        patch_width (int): Width of each patch.
        step_height (int): Step size in height.
        step_width (int): Step size in width.

    Returns:
        PIL.Image.Image | None: The padded image, or ``None`` when
        ``image_path`` does not exist.
        NOTE(review): the function is wrapped in ``st.cache_data``, so a
        ``None`` result is presumably cached as well -- confirm.

    Profiling notes:

    On my linux desktop, the time to read a 28800x50400 image is about 140 seconds.
    ```
    image.size: (28800, 50400)
    Time to open image: 0.00 seconds
    image.shape: (50400, 28800, 3)
    Time to transform image to numpy array: 27.86 seconds
    Extended image shape: (50832, 29760, 3)
    (50832, 29760, 3)
    Time to extend image: 137.57 seconds
    Time to convert back to PIL Image: 140.83 seconds
    Total time for read_image: 140.83 seconds
    ```

    On my macbook, the time to read a 28800x50400 image is about 34 seconds.
    ```
    image.size: (28800, 50400)
    Time to open image: 0.01 seconds
    image.shape: (50400, 28800, 3)
    Time to transform image to numpy array: 32.16 seconds
    Extended image shape: (50832, 29760, 3)
    (50832, 29760, 3)
    Time to extend image: 32.95 seconds
    Time to convert back to PIL Image: 34.72 seconds
    Total time for read_image: 34.72 seconds
    ```
    """
    # Check if the image exists
    if not os.path.exists(image_path):
        print(f"Image not found at: {image_path}")
        return None

    start_time = time.time()

    # Step 1: Open the image. The context manager closes the underlying file
    # as soon as the pixel data has been copied into the numpy array.
    with Image.open(image_path) as image:
        # chop the image to 1/2 along vertical axis
        # image = image.crop((0, 0, image.width, image.height // 2))
        print(f"image.size: {image.size}")
        step_time = time.time()
        print(f"Time to open image: {step_time - start_time:.2f} seconds")

        # Step 2: Transform image to numpy array
        image_arr = np.asarray(image)

    print(f"image.shape: {image_arr.shape}")
    step_time = time.time()
    print(
        f"Time to transform image to numpy array: {step_time - start_time:.2f} seconds"
    )

    # Step 3: Extend the image to ensure full coverage
    extended_image_arr = extend_image(
        image_arr, patch_height, patch_width, step_height, step_width
    )
    print(extended_image_arr.shape)
    step_time = time.time()
    print(f"Time to extend image: {step_time - start_time:.2f} seconds")

    # Step 4: Convert back to PIL Image
    extended_image = Image.fromarray(extended_image_arr)
    step_time = time.time()
    print(f"Time to convert back to PIL Image: {step_time - start_time:.2f} seconds")

    total_time = time.time()
    print(f"Total time for read_image: {total_time - start_time:.2f} seconds")
    return extended_image
def generate_patches(
    img_height, img_width, patch_height, patch_width, step_height, step_width
):
    """
    Enumerate the bounding boxes of all full-size patches on a regular grid.

    Patch origins advance by ``step_width`` horizontally and ``step_height``
    vertically; only origins where a whole patch fits inside the image are
    produced.

    Args:
        img_height (int): Image height in pixels.
        img_width (int): Image width in pixels.
        patch_height (int): Height of each patch.
        patch_width (int): Width of each patch.
        step_height (int): Vertical distance between patch origins.
        step_width (int): Horizontal distance between patch origins.

    Returns:
        dict: Maps a 0-based patch index (row-major order) to
        ``(x1, y1, x2, y2)``. Empty when the image is smaller than one patch.
    """
    # Row-major sequence of top-left corners at which a full patch fits.
    origins = (
        (left, top)
        for top in range(0, img_height - patch_height + 1, step_height)
        for left in range(0, img_width - patch_width + 1, step_width)
    )
    return {
        number: (
            left,
            top,
            min(left + patch_width, img_width),
            min(top + patch_height, img_height),
        )
        for number, (left, top) in enumerate(origins)
    }
def _reset_patch_parameters():
    """Restore the sidebar patch parameters to their defaults.

    Meant to be used as a button ``on_click`` callback: Streamlit runs
    callbacks before the script body, which is the only point at which the
    session-state entries backing already-keyed widgets may be reassigned.
    """
    st.session_state.patch_height = 1000
    st.session_state.patch_width = 1000
    st.session_state.overlap = 0.1
    st.session_state.use_custom = False


def main():
    """
    Render the Streamlit "Image Patch Visualizer" page.

    The sidebar lets the user pick patch height, width and overlap (from
    presets or as custom values). The image (mock or read from ``image_path``)
    is split into patches, a downscaled copy is shown with the selected
    patches outlined, and the selected patches themselves are shown below it.
    """
    print("---------- Starting Streamlit app...")
    st.title("Image Patch Visualizer")

    # Create sidebar for parameter configuration
    st.sidebar.header("Patch Configuration")

    # Initialize session state for parameters if they don't exist
    if "patch_height" not in st.session_state:
        st.session_state.patch_height = 1000
    if "patch_width" not in st.session_state:
        st.session_state.patch_width = 1000
    if "overlap" not in st.session_state:
        st.session_state.overlap = 0.1
    if "use_custom" not in st.session_state:
        st.session_state.use_custom = False

    # Predefined options for patch dimensions
    height_options = [500, 1000, 1440, 5000, 10000, 20000]
    width_options = [500, 1000, 1920, 5000, 10000, 20000]
    overlap_options = [0.0, 0.1, 0.2, 0.3, 0.5]

    # Option for custom values
    use_custom = st.sidebar.checkbox(
        "Use custom values", value=st.session_state.use_custom, key="use_custom"
    )

    # Set patch height
    if use_custom:
        patch_height = st.sidebar.number_input(
            "Custom Patch Height",
            min_value=100,
            max_value=50000,
            value=st.session_state.patch_height,
            step=100,
            key="patch_height",
        )
    else:
        # Fall back to the second preset when the stored value is not a preset.
        height_index = (
            height_options.index(st.session_state.patch_height)
            if st.session_state.patch_height in height_options
            else 1
        )
        patch_height = st.sidebar.selectbox(
            "Patch Height", height_options, index=height_index, key="patch_height"
        )

    # Set patch width
    if use_custom:
        patch_width = st.sidebar.number_input(
            "Custom Patch Width",
            min_value=100,
            max_value=50000,
            value=st.session_state.patch_width,
            step=100,
            key="patch_width",
        )
    else:
        width_index = (
            width_options.index(st.session_state.patch_width)
            if st.session_state.patch_width in width_options
            else 1
        )
        patch_width = st.sidebar.selectbox(
            "Patch Width", width_options, index=width_index, key="patch_width"
        )

    # Set overlap
    if use_custom:
        overlap = st.sidebar.number_input(
            "Custom Overlap (0.0-0.9)",
            min_value=0.0,
            max_value=0.9,
            value=st.session_state.overlap,
            step=0.05,
            format="%.2f",
            key="overlap",
        )
    else:
        overlap_index = (
            overlap_options.index(st.session_state.overlap)
            if st.session_state.overlap in overlap_options
            else 1
        )
        overlap = st.sidebar.selectbox(
            "Overlap", overlap_options, index=overlap_index, key="overlap"
        )

    # Calculate step sizes based on selected parameters
    overlap_height = int(patch_height * overlap)
    overlap_width = int(patch_width * overlap)

    # Calculate step size for cropping
    step_height = patch_height - overlap_height
    step_width = patch_width - overlap_width

    # Display current configuration
    st.sidebar.markdown("---")
    st.sidebar.markdown("**Current Configuration:**")
    st.sidebar.markdown(f"Patch Size: {patch_height}×{patch_width} px")
    st.sidebar.markdown(
        f"Overlap: {overlap:.2f} ({overlap_height}×{overlap_width} px)"
    )
    st.sidebar.markdown(f"Step Size: {step_height}×{step_width} px")

    # Add buttons to apply changes or reset to defaults
    st.sidebar.markdown("---")
    col1, col2 = st.sidebar.columns(2)
    with col1:
        if st.button("Apply Changes"):
            # The changes are automatically applied due to session state
            st.rerun()
    with col2:
        # The reset must happen in a callback: assigning to the session-state
        # keys of widgets that were already created in this run raises a
        # StreamlitAPIException. The click itself triggers the rerun.
        st.button("Reset Parameters", on_click=_reset_patch_parameters)

    # Read and process the image
    if mock_flg:
        original_image = create_mock_image()
        patch_mapping = generate_mock_patches(original_image.size)
    else:
        original_image = read_image(
            image_path, patch_height, patch_width, step_height, step_width
        )
        # read_image returns None when the file is missing; report it in the
        # page instead of failing on the attribute access below.
        if original_image is None:
            st.error(f"Image not found at: {image_path}")
            return
        img_width, img_height = original_image.size
        patch_mapping = generate_patches(
            img_height, img_width, patch_height, patch_width, step_height, step_width
        )

    # Without any patch there is nothing to select; min()/max() below would
    # raise on an empty mapping.
    if not patch_mapping:
        st.warning("No patches could be generated for the current configuration.")
        return

    # Get original image dimensions
    original_width, original_height = original_image.size

    # Scale the image for display purposes
    display_width = 800  # Adjust as needed
    scale_factor = display_width / original_width
    display_height = int(original_height * scale_factor)
    display_image = original_image.resize(
        (display_width, display_height), Image.LANCZOS
    )

    # Update the label to include the range of indices
    min_index = min(patch_mapping.keys())
    max_index = max(patch_mapping.keys())

    # Initialize session state for showing all patches
    if "show_all_patches" not in st.session_state:
        st.session_state.show_all_patches = False

    # Initialize session state for clearing selections
    if "clear_selections" not in st.session_state:
        st.session_state.clear_selections = False

    # Create buttons side by side using columns
    col1, col2 = st.columns(2)
    with col1:
        if st.button("Show All Patches"):
            st.session_state.show_all_patches = True
            st.session_state.clear_selections = False
            st.rerun()
    with col2:
        if st.button("Reset to Default"):
            st.session_state.clear_selections = True
            st.session_state.show_all_patches = False
            st.rerun()

    # Determine default selection based on session state
    if st.session_state.clear_selections:
        default_selection = [min_index]
    elif st.session_state.show_all_patches:
        default_selection = list(patch_mapping.keys())
    else:
        default_selection = [min_index]

    # Allow the user to select multiple patch indices
    selected_indices = st.multiselect(
        f"Select Patch Indices ({min_index} - {max_index}):",
        options=list(patch_mapping.keys()),
        default=default_selection,
        key="multiselect",
    )

    # Reset the session states if user manually changes selection
    if selected_indices and st.session_state.clear_selections:
        st.session_state.clear_selections = False
    elif (
        selected_indices != list(patch_mapping.keys())
        and st.session_state.show_all_patches
    ):
        st.session_state.show_all_patches = False

    if selected_indices:
        # Create a copy of the scaled image to draw on
        image_with_bboxes = display_image.copy()
        draw = ImageDraw.Draw(image_with_bboxes)

        # Prepare a list of colors to cycle through
        colors = ["red", "blue", "yellow", "purple"]

        # Collect patches to display later
        patches_to_display = []

        # Loop over selected indices and draw bounding boxes
        for idx_num, idx in enumerate(selected_indices):
            if idx in patch_mapping:
                bbox = patch_mapping[idx]
                # Scale bbox coordinates for the display image
                scaled_bbox = [int(coord * scale_factor) for coord in bbox]
                # Select color by cycling through the colors list
                color = colors[idx_num % len(colors)]
                draw.rectangle(
                    scaled_bbox, outline=color, width=2
                )  # Reduced width for scaled image
                # Store the patch and its color to display later
                patch = original_image.crop(bbox)
                patches_to_display.append((patch, idx, color))
            else:
                st.error(f"Patch index {idx} not found.")

        # Display the image with all bounding boxes
        st.image(
            image_with_bboxes,
            caption="Selected Patches Location",
            use_column_width=True,
        )

        # Display the patches below the main image in columns
        st.subheader("Selected Patches")
        num_columns = 3  # Number of patches per row
        columns = st.columns(num_columns)
        for position, (patch, patch_index, color) in enumerate(patches_to_display):
            col = columns[position % num_columns]
            with col:
                st.image(
                    patch,
                    caption=f"Patch {patch_index} (Color: {color})",
                    width=200,  # Fixed width for each patch
                )
    else:
        st.warning("No patches selected.")
# Script entry point: build the page when this file is executed directly
# (e.g. via `streamlit run app.py`), but not when it is imported.
if __name__ == "__main__":
    main()