Skip to content

Instantly share code, notes, and snippets.

@qubvel
Last active June 14, 2019 21:39
Show Gist options
  • Save qubvel/8a6d23e485ebb1611b330bbfa534a6ac to your computer and use it in GitHub Desktop.
Save qubvel/8a6d23e485ebb1611b330bbfa534a6ac to your computer and use it in GitHub Desktop.
Code example for converting TF weights to Keras
import pickle
from keras.layers import Conv2D, BatchNormalization, Dense
# NOTE: requires Python 3.6+ because the code relies on dicts preserving
# insertion order (guaranteed by the language from 3.7; a CPython detail in 3.6).
def get_name(name):
    """Return the layer prefix of a variable name.

    The prefix is everything before the last ``/``; a name without any ``/``
    yields the empty string. For example ``'block1/conv2d/kernel:0'`` gives
    ``'block1/conv2d'``.
    """
    head, _sep, _tail = name.rpartition('/')
    return head
def group_weights(weights):
    """Group per-variable weights into per-layer lists.

    ``weights`` is a flat mapping of ``'layer_name/var_name': np.array``.
    Entries whose layer prefix is the same (as computed by ``get_name``, i.e.
    everything before the last ``/``) are collected into one list. Entries are
    taken in the mapping's iteration order.

    Example:
        input: {
            ...: ...,
            'conv2d/kernel': <np.array>,
            'conv2d/bias': <np.array>,
            ...: ...,
        }
        output: [..., [...], [<conv2d/kernel-weights>, <conv2d/bias-weights>], [...], ...]

    Only *adjacent* entries are merged. If a layer's variables are not
    contiguous in ``weights``, they end up in separate groups.

    Args:
        weights: mapping of variable name -> array, ordered layer by layer.

    Returns:
        A list of lists of arrays, one inner list per run of consecutive
        same-layer variables. An empty mapping yields an empty list.
    """
    from itertools import groupby

    # groupby merges consecutive items with equal keys, which is exactly the
    # "start a new group whenever the layer name changes" rule. Unlike a
    # manual loop with an unconditional final append, it never emits an
    # empty group for empty input.
    runs = groupby(weights.items(), key=lambda item: get_name(item[0]))
    return [[value for _name, value in run] for _layer, run in runs]
def load_weights(model, weights):
    """Load grouped weights into a model's Conv2D, BatchNormalization and Dense layers.

    Layers of those three types are visited in ``model.layers`` order. Each
    one receives the next group from ``weights`` via ``layer.set_weights``.
    Other layer types are skipped and consume no group. Groups left over
    after the last matching layer are ignored.

    Args:
        model: Keras model whose matching layers appear in the same order as
            the groups in ``weights``.
        weights: list of per-layer weight lists, e.g. as returned by
            ``group_weights``.

    Raises:
        IndexError: if the model has more matching layers than ``weights``
            has groups.
    """
    i = 0
    for layer in model.layers:
        if isinstance(layer, (Conv2D, BatchNormalization, Dense)):
            print(layer)  # progress output: which layer is being filled
            # Read from the ``weights`` argument (not a module-level global)
            # so the function works for any caller.
            layer.set_weights(weights[i])
            i += 1
# read saved TF model weights
# NOTE(review): `model_name` is not defined in this snippet; it must be set
# beforehand to the checkpoint sub-directory name.
# Only unpickle files you produced yourself: pickle.load can execute
# arbitrary code embedded in a malicious file.
with open('../../checkpoints/{}/weights.pkl'.format(model_name), 'rb') as f:
    weights = pickle.load(f)
# convert weights to keras format: a list of per-layer weight lists
groupped_weights = group_weights(weights)
# create model same as tf
# NOTE(review): EfficientNetB0 is not imported in this snippet, and `...` is
# presumably a placeholder for its constructor arguments. The architecture
# must match the TF model so its Conv2D/BatchNormalization/Dense layers
# appear in the same order as the weight groups.
model = EfficientNetB0(...)
# load weights layer by layer
load_weights(model, groupped_weights)
@qubvel
Copy link
Author

qubvel commented Jun 14, 2019

Code for dumping TF model weights to `weights.pkl` (paste it right after the checkpoint has been restored):

import os
import pickle

# Dump every global TF variable's current value to <ckpt_dir>/weights.pkl.
# NOTE(review): assumes `tf` (TF1 API), `ckpt_dir` and a default session with
# the checkpoint restored are already set up; var.eval() needs that session.
vars_global = tf.global_variables()

# name -> value; dict insertion order follows tf.global_variables(), which
# the Keras-side grouping relies on.
model_vars = {}
for var in vars_global:
    try:
        model_vars[var.name] = var.eval()
    except Exception as err:
        # Best-effort: skip variables that cannot be evaluated, but report
        # why. Unlike a bare `except:`, this does not swallow
        # KeyboardInterrupt/SystemExit.
        print("For var={}, an exception occurred: {}".format(var.name, err))

with open(os.path.join(ckpt_dir, 'weights.pkl'), 'wb') as f:
    pickle.dump(model_vars, f)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment