Skip to content

Instantly share code, notes, and snippets.

@ianstenbit
Created December 15, 2022 23:52
Show Gist options
  • Save ianstenbit/5998f993bdea735bce672fb5b0c57ac9 to your computer and use it in GitHub Desktop.
Save ianstenbit/5998f993bdea735bce672fb5b0c57ac9 to your computer and use it in GitHub Desktop.
Converting EfficientNet Weights
# Loaded with python -i ../path/to/old/efficientnet/impl
# NOTE(review): EfficientNetV2S and tf are not defined in this script; they are
# presumably brought into the session by the old implementation module that is
# loaded via `python -i` (see the line above) — confirm before running standalone.
old_model = EfficientNetV2S(include_top=True, classes=1000, include_rescaling=False, weights="./classification-v0.h5")
# Random input used at the end of the script to compare old vs. new model outputs.
img = tf.random.normal((1, 512, 512, 3))
import keras_cv
import numpy
import h5py
weights = "classification-v0.h5"
# The checkpoint is only read, never written, so open it read-only. The previous
# "r+" mode needlessly required write permission, took a write lock on the
# file, and risked modifying the source weights.
# The handle is kept open here so its datasets can be read later in the script.
f1 = h5py.File(weights, 'r')
def traverse_datasets(hdf_file):
    """Yield ``(short_name, array)`` for every dataset stored under *hdf_file*.

    Groups are explored recursively, depth-first, visiting children in the
    order ``keys()`` reports them. Each dataset is read in full with
    ``node[()]``. Its full HDF5 path (for example ``/outer/inner/kernel:0``)
    is shortened by dropping the leading top-level group, so the name
    yielded for that example is ``inner/kernel:0``.
    """

    def walk(group, parent=''):
        # Depth-first walk producing (absolute_path, contents) for datasets only.
        for child_name in group.keys():
            node = group[child_name]
            child_path = parent + '/' + child_name
            if isinstance(node, h5py.Dataset):
                yield (child_path, node[()])
            elif isinstance(node, h5py.Group):
                yield from walk(node, child_path)

    for full_path, contents in walk(hdf_file):
        # full_path starts with "/", so split() gives ['', top_group, ...];
        # slicing from index 2 drops the empty entry and the top-level group.
        short_name = "/".join(full_path.split("/")[2:])
        yield (short_name, contents)
# Materialize every (name, array) pair from the checkpoint. traverse_datasets
# reads each dataset fully into memory (item[()]), so the HDF5 file handle is
# no longer needed afterwards and is released rather than left open.
npy = list(traverse_datasets(f1))
f1.close()
def get_weights_by_keys(weights_for_layer, keys):
    """Fetch several entries of a name -> weight mapping at once.

    Args:
        weights_for_layer: mapping from weight name to weight value.
        keys: weight names to fetch; their order fixes the output order.

    Returns:
        A list holding ``weights_for_layer[key]`` for each key, in order.

    Raises:
        KeyError: if any requested name is missing from the mapping.
    """
    return list(map(weights_for_layer.__getitem__, keys))
model = keras_cv.models.EfficientNetV2S(include_top=True, classes=1000, include_rescaling=False)

# For each number of checkpoint tensors that can match a layer, the name
# suffixes (appended directly to the layer name) in the exact order that
# layer's set_weights() is called with. Suffixes starting with "/" address the
# layer's own variables; the others address sub-layers of a block layer whose
# name already ends in a separator.
_SUFFIXES_BY_COUNT = {
    # Kernel + bias (presumably the classification Dense layer).
    2: ("/kernel:0", "/bias:0"),
    # A batchnorm layer.
    4: ("/gamma:0", "/beta:0", "/moving_mean:0", "/moving_variance:0"),
    # A FusedMBConvBlock layer with no squeeze excite, no expansion phase.
    5: (
        "project_conv/kernel:0",
        "project_bn/gamma:0", "project_bn/beta:0",
        "project_bn/moving_mean:0", "project_bn/moving_variance:0",
    ),
    # A FusedMBConvBlock layer with no squeeze excite.
    10: (
        "expand_conv/kernel:0",
        "expand_bn/gamma:0", "expand_bn/beta:0",
        "project_conv/kernel:0",
        "project_bn/gamma:0", "project_bn/beta:0",
        "expand_bn/moving_mean:0", "expand_bn/moving_variance:0",
        "project_bn/moving_mean:0", "project_bn/moving_variance:0",
    ),
    # A MBConvBlock layer with everything (expansion, depthwise, SE, projection).
    19: (
        "expand_conv/kernel:0",
        "expand_bn/gamma:0", "expand_bn/beta:0",
        "dwconv2/depthwise_kernel:0",
        "bn/gamma:0", "bn/beta:0",
        "se_reduce/kernel:0", "se_reduce/bias:0",
        "se_expand/kernel:0", "se_expand/bias:0",
        "project_conv/kernel:0",
        "project_bn/gamma:0", "project_bn/beta:0",
        "expand_bn/moving_mean:0", "expand_bn/moving_variance:0",
        "bn/moving_mean:0", "bn/moving_variance:0",
        "project_bn/moving_mean:0", "project_bn/moving_variance:0",
    ),
}

total_copied_weights = 0
for layer in model.layers:
    # Every checkpoint tensor whose name begins with this layer's name.
    weights_for_layer = {name: weights for name, weights in npy if name.startswith(layer.name)}
    num_weights = len(weights_for_layer)
    if not num_weights:
        continue
    if num_weights == 1:
        # A single tensor: no ordering question, copy it as-is.
        layer.set_weights(list(weights_for_layer.values()))
    elif num_weights in _SUFFIXES_BY_COUNT:
        ordered_keys = [layer.name + suffix for suffix in _SUFFIXES_BY_COUNT[num_weights]]
        layer.set_weights(get_weights_by_keys(weights_for_layer, ordered_keys))
    else:
        # Unknown layer layout: dump diagnostics, then stop.
        print(num_weights)
        print(weights_for_layer.keys())
        print(type(layer))
        print(layer.name)
        raise NotImplementedError()
    total_copied_weights += num_weights
print(total_copied_weights)
model.save_weights("converted.h5")
# Sanity check: verify the converted model against the original one.
old_weights = old_model.trainable_weights
new_weights = model.trainable_weights
# Check BatchNorm moving statistics (non-trainable) as well as trainable
# weights. Zip the lists and compare their lengths explicitly instead of
# indexing by range(len(old)), which ignored extra new weights and raised
# IndexError when the new list was shorter.
weight_groups = (
    ("trainable", old_weights, new_weights),
    ("non-trainable", old_model.non_trainable_weights, model.non_trainable_weights),
)
for group_name, old_group, new_group in weight_groups:
    if len(old_group) != len(new_group):
        print(f"{group_name} weight count differs: old={len(old_group)} new={len(new_group)}")
    for index, (old_w, new_w) in enumerate(zip(old_group, new_group)):
        # Check shapes first so an element-wise == never sees mismatched shapes.
        if old_w.shape != new_w.shape or not tf.reduce_all(old_w == new_w):
            print(f"{group_name} weight {index} differs: {old_w.name} vs {new_w.name}")
# Run each model once and reuse the outputs for both reports.
old_output = old_model(img)
new_output = model(img)
print(f"Outputs match: {tf.reduce_all(old_output == new_output).numpy()}")
print(f"Max abs output difference: {tf.reduce_max(tf.abs(old_output - new_output)).numpy()}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment