Created
December 15, 2022 23:52
-
-
Save ianstenbit/5998f993bdea735bce672fb5b0c57ac9 to your computer and use it in GitHub Desktop.
Converting EfficientNet Weights
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Loaded with python -i ../path/to/old/efficientnet/impl
# NOTE(review): EfficientNetV2S below is the *old* implementation brought into
# scope by the interactive `python -i` session above — it is not the
# keras_cv.models.EfficientNetV2S imported further down; confirm before reuse.
old_model = EfficientNetV2S(include_top=True, classes=1000, include_rescaling=False, weights="./classification-v0.h5")
# Dummy input, used at the end to compare old-model vs. converted-model outputs.
img = tf.random.normal((1, 512, 512, 3))
import keras_cv
import numpy
import h5py
# Source checkpoint whose weights get copied into the keras_cv model.
weights = "classification-v0.h5"
f1 = h5py.File(weights,'r+')  # opened read/write ('r+'); read-only 'r' would presumably suffice — confirm
def traverse_datasets(hdf_file):
    """Yield (name, value) pairs for every dataset stored in an HDF5 file.

    Groups are walked depth-first. The yielded name has the leading empty
    component and the top-level group name stripped off, so it matches the
    per-layer weight names (e.g. "conv/kernel:0").
    """
    def _walk(group, base=''):
        # Recurse into sub-groups; emit (full_path, array) for each dataset.
        for key, item in group.items():
            full = f'{base}/{key}'
            if isinstance(item, h5py.Dataset):
                yield (full, item[()])
            elif isinstance(item, h5py.Group):
                yield from _walk(item, full)

    for full, value in _walk(hdf_file):
        # full is "/<top_group>/<...>/<weight>:0"; split("/") gives
        # ['', '<top_group>', ...], so [2:] drops both leading components.
        yield ("/".join(full.split("/")[2:]), value)
# Materialize every (name, value) dataset pair from the checkpoint up front.
# (A plain list() call — the original identity comprehension added nothing.)
npy = list(traverse_datasets(f1))
def get_weights_by_keys(weights_for_layer, keys):
    """Return the weight arrays for *keys*, in the order the keys are given.

    Raises KeyError if any key is missing from *weights_for_layer*.
    """
    ordered = []
    for key in keys:
        ordered.append(weights_for_layer[key])
    return ordered
# Fresh keras_cv model (randomly initialized) that receives the converted weights.
model = keras_cv.models.EfficientNetV2S(include_top=True, classes=1000, include_rescaling=False)
# Running count of individual weight tensors copied, printed at the end.
total_copied_weights = 0
for layer in model.layers:
    # All checkpoint datasets whose (stripped) name starts with this layer's
    # name; the *count* of matches identifies the layer kind below.
    weights_for_layer = {name:weights for name, weights in npy if name.startswith(layer.name)}
    num_weights = len(weights_for_layer)
    if num_weights == 0:
        # Weightless layer (e.g. activation / pooling) — nothing to copy.
        continue
    elif num_weights == 1:
        # Exactly one weight tensor, so ordering is unambiguous.
        layer.set_weights(list(weights_for_layer.values()))
    elif num_weights == 2:
        # This is a dense layer
        layer.set_weights(get_weights_by_keys(weights_for_layer, [f"{layer.name}/kernel:0", f"{layer.name}/bias:0"]))
    elif num_weights == 4:
        # This is a batchnorm layer
        layer.set_weights(get_weights_by_keys(weights_for_layer, [f"{layer.name}/gamma:0", f"{layer.name}/beta:0", f"{layer.name}/moving_mean:0", f"{layer.name}/moving_variance:0"]))
    elif num_weights == 5:
        # This is a FusedMBConvBlock layer with no squeeze excite, no expansion phase
        # NOTE(review): unlike the 2- and 4-weight cases there is no "/" between
        # layer.name and the sub-layer name — presumably block weight names in
        # the checkpoint already concatenate directly; confirm against the file.
        layer.set_weights(get_weights_by_keys(weights_for_layer, [f"{layer.name}project_conv/kernel:0", f"{layer.name}project_bn/gamma:0", f"{layer.name}project_bn/beta:0", f"{layer.name}project_bn/moving_mean:0", f"{layer.name}project_bn/moving_variance:0"]))
    elif num_weights == 10:
        # This is a FusedMBConvBlock layer with no squeeze excite
        layer.set_weights(get_weights_by_keys(weights_for_layer, [f"{layer.name}expand_conv/kernel:0", f"{layer.name}expand_bn/gamma:0", f"{layer.name}expand_bn/beta:0", f"{layer.name}project_conv/kernel:0", f"{layer.name}project_bn/gamma:0", f"{layer.name}project_bn/beta:0", f"{layer.name}expand_bn/moving_mean:0", f"{layer.name}expand_bn/moving_variance:0", f"{layer.name}project_bn/moving_mean:0", f"{layer.name}project_bn/moving_variance:0"]))
    elif num_weights == 19:
        # This is a MBConvBlock layer with everything
        layer.set_weights(get_weights_by_keys(weights_for_layer, [f"{layer.name}expand_conv/kernel:0", f"{layer.name}expand_bn/gamma:0", f"{layer.name}expand_bn/beta:0", f"{layer.name}dwconv2/depthwise_kernel:0", f"{layer.name}bn/gamma:0", f"{layer.name}bn/beta:0", f"{layer.name}se_reduce/kernel:0", f"{layer.name}se_reduce/bias:0", f"{layer.name}se_expand/kernel:0", f"{layer.name}se_expand/bias:0", f"{layer.name}project_conv/kernel:0", f"{layer.name}project_bn/gamma:0", f"{layer.name}project_bn/beta:0", f"{layer.name}expand_bn/moving_mean:0", f"{layer.name}expand_bn/moving_variance:0", f"{layer.name}bn/moving_mean:0", f"{layer.name}bn/moving_variance:0", f"{layer.name}project_bn/moving_mean:0", f"{layer.name}project_bn/moving_variance:0"]))
    else:
        # Unexpected weight count — dump diagnostics and abort the conversion.
        print(num_weights)
        print(weights_for_layer.keys())
        print(type(layer))
        print(layer.name)
        raise NotImplementedError()
    total_copied_weights += num_weights
# Sanity check: report how many weight tensors were copied, then persist the
# converted model in Keras HDF5 format.
print(total_copied_weights)
model.save_weights("converted.h5")
# Verify the conversion: every trainable weight tensor must match exactly,
# and the two models must produce identical outputs on the dummy input.
old_weights = old_model.trainable_weights
new_weights = model.trainable_weights
if len(old_weights) != len(new_weights):
    # A count mismatch would silently truncate the zip below — surface it.
    print(f"Weight count mismatch: {len(old_weights)} vs {len(new_weights)}")
for index, (old_w, new_w) in enumerate(zip(old_weights, new_weights)):
    # Print the index of any tensor that differs (empty output == all match).
    if not tf.reduce_all(old_w == new_w):
        print(index)
print(f"Outputs match: {tf.reduce_all(old_model(img) == model(img)).numpy()}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment