Skip to content

Instantly share code, notes, and snippets.

@mgraczyk
Last active August 29, 2019 07:46
Show Gist options
  • Save mgraczyk/269714a749dab895d176600d9c9441a5 to your computer and use it in GitHub Desktop.
Save mgraczyk/269714a749dab895d176600d9c9441a5 to your computer and use it in GitHub Desktop.
Fix `input_dtype` errors in Keras model HDF5 (.h5) files saved by Keras versions before 2.0
# Usage: fix_keras_model.py old_model.h5 new_model.h5
import h5py
import shutil
import json
import sys
def fix_input_dtype(model_config):
    """Rename ``input_dtype`` to ``dtype`` in every layer config, in place.

    Two layouts of the serialized ``model_config`` attribute are handled:

    * ``{"config": [layer, ...]}``: ``config`` is a plain list of layers (as
      written for ``Sequential`` models by old Keras). Indexing it with
      ``["layers"]`` raises ``TypeError: list indices must be integers``.
    * ``{"config": {"layers": [layer, ...], ...}}``: ``config`` is a dict
      holding a ``layers`` list.

    A layer that is itself a model (it has its own layer list) is processed
    recursively.

    Args:
        model_config: Parsed JSON of the ``model_config`` attribute, or of a
            single layer entry.

    Returns:
        The number of layer configs whose ``input_dtype`` was renamed.
    """
    inner = model_config.get("config")
    if isinstance(inner, list):
        layers = inner
    elif isinstance(inner, dict):
        layers = inner.get("layers") or []
    else:
        layers = []

    renamed = 0
    for layer in layers:
        if not isinstance(layer, dict):
            continue
        layer_config = layer.get("config")
        if isinstance(layer_config, dict):
            dtype = layer_config.pop("input_dtype", None)
            if dtype is not None:
                layer_config["dtype"] = dtype
                renamed += 1
        # The layer may be a nested model carrying its own list of layers.
        renamed += fix_input_dtype(layer)
    return renamed


def _as_text(value):
    """Return *value* as text, decoding UTF-8 ``bytes`` if necessary.

    Depending on the h5py / Python version, string attributes are returned
    either as ``bytes`` or as ``str``; both are accepted here.
    """
    if isinstance(value, bytes):
        return value.decode("utf-8")
    return value


def main(argv):
    """Copy a Keras .h5 model and rewrite ``input_dtype`` -> ``dtype`` in it.

    Args:
        argv: Command-line arguments: ``[prog, old_model.h5, new_model.h5]``.

    Raises:
        SystemExit: If the number of arguments is wrong.
        ValueError: If the file has no ``model_config`` attribute.
    """
    if len(argv) != 3:
        sys.exit("Usage: fix_keras_model.py old_model.h5 new_model.h5")
    input_model_path, output_model_path = argv[1], argv[2]

    shutil.copyfile(input_model_path, output_model_path)
    with h5py.File(output_model_path, "r+") as out_h5:
        raw_config = out_h5.attrs.get("model_config")
        if raw_config is None:
            raise ValueError(
                "%s has no 'model_config' attribute" % input_model_path)
        config = json.loads(_as_text(raw_config))
        fix_input_dtype(config)

        # Write the attribute back with the same string type it was read
        # with (bytes or str), so loaders expecting either keep working.
        new_config_str = json.dumps(config)
        if isinstance(raw_config, bytes):
            new_config_str = new_config_str.encode("utf-8")
        out_h5.attrs.modify("model_config", new_config_str)

    # Check that it worked.
    from keras.models import load_model
    load_model(output_model_path)


if __name__ == "__main__":
    main(sys.argv)
@DylanCope
Copy link

Doesn't work for Python 3 unfortunately :/

@Vladimir-Yashin
Copy link

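To make it work with Python 3 when h5py returns the `model_config` attribute as `bytes` (h5py < 3.0; with h5py >= 3.0 the attribute is already a `str`, so the `.decode` call below fails and should be skipped):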

  1. Replace
    v = out_h5.attrs.get("model_config")
    with
    v = out_h5.attrs.get("model_config").decode("utf-8")

  2. Replace
    new_config_str = json.dumps(config)
    with
    new_config_str = json.dumps(config).encode("utf-8")

@Zumbalamambo
Copy link

Zumbalamambo commented Aug 28, 2017

This is my code.

import h5py
import shutil
import json
import sys

input_model_path = 'age_model.h5'
output_model_path = 'modify_Age.model.h5'


def fix_input_dtype(model_config):
    """Rename ``input_dtype`` to ``dtype`` in every layer config, in place.

    ``model_config["config"]`` may be either a plain list of layers (as
    written for old ``Sequential`` models; indexing it with ``["layers"]`` is
    what raises ``TypeError: list indices must be integers or slices, not
    str``) or a dict holding a ``layers`` list. Nested models are processed
    recursively.

    Returns:
        The number of layer configs whose ``input_dtype`` was renamed.
    """
    inner = model_config.get("config")
    if isinstance(inner, list):
        layers = inner
    elif isinstance(inner, dict):
        layers = inner.get("layers") or []
    else:
        layers = []

    renamed = 0
    for layer in layers:
        if not isinstance(layer, dict):
            continue
        layer_config = layer.get("config")
        if isinstance(layer_config, dict):
            dtype = layer_config.pop("input_dtype", None)
            if dtype is not None:
                layer_config["dtype"] = dtype
                renamed += 1
        # The layer may be a nested model carrying its own list of layers.
        renamed += fix_input_dtype(layer)
    return renamed


def main():
    """Copy ``input_model_path`` to ``output_model_path`` and fix its config."""
    shutil.copyfile(input_model_path, output_model_path)

    with h5py.File(output_model_path, "r+") as out_h5:
        raw_config = out_h5.attrs.get("model_config")
        if raw_config is None:
            raise ValueError(
                "%s has no 'model_config' attribute" % input_model_path)
        # h5py may hand back bytes or str depending on its version.
        if isinstance(raw_config, bytes):
            config = json.loads(raw_config.decode("utf-8"))
        else:
            config = json.loads(raw_config)

        fix_input_dtype(config)

        # Write back with the same string type that was read.
        new_config_str = json.dumps(config)
        if isinstance(raw_config, bytes):
            new_config_str = new_config_str.encode("utf-8")
        out_h5.attrs.modify("model_config", new_config_str)

    # Check that it worked.
    from keras.models import load_model
    load_model(output_model_path)


if __name__ == "__main__":
    main()

It raises the following error:

for i, l in enumerate(config["config"]["layers"]):

TypeError: list indices must be integers or slices, not str

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment