-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathextractWeights.py
More file actions
58 lines (48 loc) · 2.65 KB
/
extractWeights.py
File metadata and controls
58 lines (48 loc) · 2.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import numpy as np
import tensorflow as tf
from tfModel import ZF_UNET_224
# ------------------------------------------------------------------------------------
# This script extracts weights and parameters from a pretrained Keras U-Net model (ZF_UNET_224).
# It loads the model, iterates through each layer, and saves convolution kernels, biases,
# and batch normalization parameters (gamma, beta, running mean, running variance) as .npy files.
# The convolution kernels are transposed from (H, W, In_C, Out_C) to (Out_C, In_C, H, W) for downstream use.
# A markdown summary of the model architecture is also generated.
# ------------------------------------------------------------------------------------
# Build the ZF_UNET_224 network and load its pretrained weights from the .h5 file.
# NOTE: the weights path is relative, so it is resolved against the current working directory.
model = ZF_UNET_224()
h5_model_path = "zf_unet_224.h5"
model.load_weights(h5_model_path)
# Output directory for the extracted .npy files; exist_ok=True means an
# already-existing directory is not treated as an error on re-runs.
weights_dir = "pretrainedKernels_"
os.makedirs(weights_dir, exist_ok=True)
def model_summary_to_markdown(model, filename='model_summary.md'):
    """
    Save the Keras model summary to a markdown file for documentation.

    Args:
        model: Object exposing ``summary(print_fn=...)``; every string handed
            to the callback is written to the file.
        filename: Path of the markdown file to create (overwritten if present).
    """
    # The encoding is explicit because the summary text may contain non-ASCII
    # table-drawing characters that a platform default codec cannot encode.
    with open(filename, 'w', encoding='utf-8') as f:
        f.write('## Model Summary\n\n')

        def write_line(text, line_break=True, **_unused):
            # Some Keras releases invoke print_fn with a `line_break` keyword;
            # accepting it (and any other keyword) avoids a TypeError.
            f.write((text + '\n') if line_break else text)

        model.summary(print_fn=write_line)
# Save model summary for reference
model_summary_to_markdown(model)

# Walk every layer and export its parameters as .npy files in weights_dir.
for layer in model.layers:
    # Fetch the weights once per layer and reuse the list below.
    layer_weights = layer.get_weights()
    if not layer_weights:
        continue  # layer has no parameters, nothing to export

    print(f"Layer: {layer.name}, Shape(s): {[w.shape for w in layer_weights]}")

    # Conv2D: the first entry is the kernel, an optional second entry is the bias.
    # NOTE(review): in some Keras versions Conv2DTranspose also passes this
    # isinstance check and may use a different kernel axis order - confirm the
    # model contains none before relying on the layout below.
    if isinstance(layer, tf.keras.layers.Conv2D):
        kernel = layer_weights[0]
        # Permute axes from (H, W, In_C, Out_C) to (Out_C, In_C, H, W).
        transposed_kernel = np.transpose(kernel, (3, 2, 0, 1))
        np.save(os.path.join(weights_dir, f"{layer.name}_weights.npy"), transposed_kernel)
        # A layer built without a bias has only the kernel; skip the bias file
        # instead of failing on tuple unpacking.
        if len(layer_weights) > 1:
            np.save(os.path.join(weights_dir, f"{layer.name}_bias.npy"), layer_weights[1])

    # BatchNormalization: gamma/beta exist only when scale/center are enabled,
    # so derive the expected parameter names from the layer configuration.
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        param_names = []
        if getattr(layer, "scale", True):
            param_names.append("gamma")
        if getattr(layer, "center", True):
            param_names.append("beta")
        param_names += ["rmean", "rvar"]
        if len(param_names) != len(layer_weights):
            # Fail loudly rather than writing arrays under the wrong names.
            raise ValueError(
                f"Unexpected number of weights for {layer.name}: "
                f"expected {len(param_names)}, got {len(layer_weights)}"
            )
        for suffix, values in zip(param_names, layer_weights):
            np.save(os.path.join(weights_dir, f"{layer.name}_{suffix}.npy"), values)

# Report the directory actually used rather than a hard-coded copy of its name.
print(f"Weights saved in '{weights_dir}' folder.")