Skip to content

Instantly share code, notes, and snippets.

@davecg
Forked from mgraczyk/fix_keras_model.py
Last active April 9, 2017 00:00
Show Gist options
  • Save davecg/396a65abde32590fbb43c15951de41f4 to your computer and use it in GitHub Desktop.
Save davecg/396a65abde32590fbb43c15951de41f4 to your computer and use it in GitHub Desktop.
Fix input_dtype errors in pre-2.0 Keras model HDF5 (.h5) files
import h5py
import shutil
import json
import sys
def fix_weight_file(model_path):
    """Rename ``input_dtype`` to ``dtype`` in a Keras HDF5 file's model config.

    The file at ``model_path`` is opened read/write and modified in place:
    the JSON stored in its ``model_config`` attribute is parsed, every layer
    config that carries an ``input_dtype`` key has that key moved to
    ``dtype``, and the resulting JSON is written back to the same attribute.

    Both layouts of the top-level ``config`` entry are handled: a dict with a
    ``"layers"`` list, and a bare list of layer dicts (the form Sequential
    models were serialized with in older Keras releases).

    Args:
        model_path: Path to the HDF5 model file to patch in place.

    Raises:
        ValueError: If the file has no ``model_config`` attribute.
    """
    with h5py.File(model_path, "r+") as out_h5:
        raw_config = out_h5.attrs.get("model_config")
        if raw_config is None:
            raise ValueError(
                "%r has no 'model_config' attribute; nothing to fix"
                % (model_path,)
            )
        # h5py may return the attribute as bytes; json.loads only accepts
        # bytes on Python 3.6+, so decode explicitly.
        if isinstance(raw_config, bytes):
            raw_config = raw_config.decode("utf-8")
        config = json.loads(raw_config)

        model_config = config["config"]
        if isinstance(model_config, dict):
            layers = model_config["layers"]
        else:
            # Sequential-style config: the layer list is stored directly.
            layers = model_config

        for layer in layers:
            layer_config = layer["config"]
            dtype = layer_config.pop("input_dtype", None)
            if dtype is not None:
                layer_config["dtype"] = dtype

        # modify() replaces the value of the existing attribute rather than
        # re-creating it.
        out_h5.attrs.modify("model_config", json.dumps(config))
if __name__ == '__main__':
    # Usage: fix_keras_model.py old_model.h5 new_model.h5
    #
    # The input file is never modified: it is copied to the output path and
    # only the copy is patched.
    if len(sys.argv) < 3:
        # Exit with a usage message (printed to stderr, exit status 1)
        # instead of an IndexError traceback when arguments are missing.
        sys.exit("Usage: fix_keras_model.py old_model.h5 new_model.h5")
    input_model_path = sys.argv[1]
    output_model_path = sys.argv[2]
    shutil.copyfile(input_model_path, output_model_path)
    fix_weight_file(output_model_path)

    # Check that it worked: loading raises if the config is still rejected.
    # Imported here so keras is only loaded once the file has been patched.
    from keras.models import load_model
    load_model(output_model_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment