Load weights to Keras model from file allowing for differences between file and model
import numpy as np
import h5py
import keras.backend as K


def str_shape(x):
    return 'x'.join(map(str, x.shape))

def load_weights(model, filepath, lookup={}, ignore=[], transform=None, verbose=True):
    """Modified version of keras load_weights that loads as much as it can.

    Useful for transfer learning.

    Reads the weights of the layers stored in the file and copies them to the
    model's layers. The name of each layer is used to match the file's layers
    with the model's. It is possible to have layers in the model that do not
    appear in the file.

    Loading stops as soon as a problem is encountered, and the weights of the
    file layer that first caused the problem are returned.

    # Arguments
        model: Model
            target model
        filepath: str
            source HDF5 file
        lookup: dict (optional)
            By default, the weights of each layer in the file are copied to
            the layer with the same name in the model. With lookup you can
            map a file layer name to a different model layer name, or to a
            list of model layer names, in which case the same weights are
            copied to all of those model layers.
        ignore: list (optional)
            List of file layer names whose load problems are skipped instead
            of aborting the load (pass '*' to skip every problem).
        transform: function (optional)
            Optional function that receives the list of weights read from a
            layer in the file and the model layer object into which these
            weights should be loaded. It may return a modified list of
            weights, or None to keep the file weights unchanged; see the
            sketch after the function.
        verbose: bool
            Highly recommended to keep this True and to follow the printed
            messages.

    # Returns
        The weights of the file layer that first caused the load to abort,
        or None on a successful load.
    """
    if verbose:
        print('Loading', filepath, 'to', model.name)
    with h5py.File(filepath, mode='r') as f:
        # new file format
        layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]

        # we batch weight value assignments in a single backend call
        # which provides a speedup in TensorFlow.
        weight_value_tuples = []
        for name in layer_names:
            if verbose:
                print(name, end=' ')
            g = f[name]
            weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
            if len(weight_names):
                weight_values = [g[weight_name] for weight_name in weight_names]
                if verbose:
                    print('loading', ' '.join(str_shape(w) for w in weight_values),
                          end=' ')
                target_names = lookup.get(name, name)
                if isinstance(target_names, str):
                    target_names = [target_names]
                # handle the case where lookup asks to send the same weights
                # to multiple layers
                target_names = [target_name for target_name in target_names if
                                target_name == name or target_name not in layer_names]
                for target_name in target_names:
                    if verbose:
                        print(target_name, end=' ')
                    try:
                        layer = model.get_layer(name=target_name)
                    except ValueError:  # no layer with this name in the model
                        layer = None
                    if layer:
                        # the same weight_values are copied to each of the target layers
                        symbolic_weights = (layer.trainable_weights +
                                            layer.non_trainable_weights)
                        if transform is not None:
                            transformed_weight_values = transform(weight_values, layer)
                            if transformed_weight_values is not None:
                                if verbose:
                                    print('(%d->%d)' % (len(weight_values),
                                                        len(transformed_weight_values)),
                                          end=' ')
                                weight_values = transformed_weight_values
                        problem = len(symbolic_weights) != len(weight_values)
                        if problem and verbose:
                            print('(bad #wgts)', end=' ')
                        if not problem:
                            weight_value_tuples += zip(symbolic_weights, weight_values)
                    else:
                        problem = True
                    if problem:
                        if verbose:
                            if name in ignore or ignore == '*':
                                print('(skipping)', end=' ')
                            else:
                                print('ABORT')
                        if not (name in ignore or ignore == '*'):
                            # commit whatever loaded cleanly, then hand back
                            # the weights of the layer that failed
                            K.batch_set_value(weight_value_tuples)
                            return [np.array(w) for w in weight_values]
                if verbose:
                    print()
            else:
                if verbose:
                    print('skipping empty file layer')
        K.batch_set_value(weight_value_tuples)
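
A minimal usage sketch for reference. The HDF5 file name ('pretrained.h5'), the layer names, and the toy model below are assumptions for illustration, not part of the gist:

from keras.models import Sequential
from keras.layers import Dense

model = Sequential([
    Dense(64, activation='relu', input_shape=(100,), name='hidden'),
    Dense(10, activation='softmax', name='output'),
])

# copy the file layer 'dense_1' into the model layer 'hidden'; skip,
# instead of aborting on, any problem with the file layer 'dense_2'
leftover = load_weights(model, 'pretrained.h5',
                        lookup={'dense_1': 'hidden'},
                        ignore=['dense_2'])
if leftover is not None:
    print('load aborted; first problematic file weights:',
          [w.shape for w in leftover])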
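
The transform hook can adjust weights on the fly before they are assigned. Below is a sketch under an assumed, hypothetical scenario (a Dense kernel stored transposed in the file); the calling convention is what matters: the hook receives the file weights and the target layer, and returns replacement weights, or None to keep the file weights unchanged:

import numpy as np
from keras.layers import Dense

def fix_dense(weight_values, layer):
    # only touch Dense layers that carry a (kernel, bias) pair
    if isinstance(layer, Dense) and len(weight_values) == 2:
        kernel = np.asarray(weight_values[0])
        # hypothetical mismatch: the file kernel is transposed
        if kernel.shape[::-1] == layer.get_weights()[0].shape:
            return [kernel.T, weight_values[1]]
    return None  # None keeps the file weights unchanged

load_weights(model, 'pretrained.h5', transform=fix_dense)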
@sinclairnick
FYI: As of writing, this method does not work for tensorflow.keras @2.1.0
