Skip to content

Instantly share code, notes, and snippets.

@jmduarte
Created February 26, 2018 18:05
Show Gist options
  • Save jmduarte/64edd2652286482ef151bd22b9480890 to your computer and use it in GitHub Desktop.
Save jmduarte/64edd2652286482ef151bd22b9480890 to your computer and use it in GitHub Desktop.
model to json
import sys
import os
import keras
from keras.models import load_model
from optparse import OptionParser
import numpy as np
from keras import backend as K
def print_model_to_json(keras_model, outfile_name):
    """Echo a model's architecture JSON to stdout and save it pretty-printed.

    The string returned by ``keras_model.to_json()`` is printed as-is, then
    re-parsed and written to ``outfile_name`` with sorted keys, a 4-space
    indent and a trailing newline.

    Args:
        keras_model: object exposing ``to_json()`` that returns a JSON string
            (a Keras model in this script).
        outfile_name: path of the JSON file to create or overwrite.
    """
    import json

    json_string = keras_model.to_json()
    print(json_string)
    obj = json.loads(json_string)
    # Open only after serialization succeeded, so a failure in to_json()/loads
    # does not leave behind an empty or truncated file. Text mode is required
    # because json.dump writes str, which a binary ('wb') handle rejects on
    # Python 3.
    with open(outfile_name, 'w') as outfile:
        json.dump(obj, outfile, sort_keys=True, indent=4,
                  separators=(',', ': '))
        outfile.write('\n')
def huber_loss(y_true, y_pred, clip_value=1e-07):
    """Element-wise Huber loss between ``y_true`` and ``y_pred``.

    With ``x = y_true - y_pred`` the loss is ``0.5 * x**2`` where
    ``|x| < clip_value`` and ``clip_value * (|x| - 0.5 * clip_value)``
    elsewhere. An infinite ``clip_value`` reduces it to ``0.5 * x**2``.
    See https://en.wikipedia.org/wiki/Huber_loss and
    https://medium.com/@karpathy/yes-you-should-understand-backprop-e2f06eab496b
    for details.

    Args:
        y_true: target tensor.
        y_pred: prediction tensor (presumably broadcast-compatible with
            ``y_true`` -- confirm with callers).
        clip_value: strictly positive threshold between the quadratic and
            linear regimes. NOTE(review): the 1e-07 default places almost
            every residual in the linear regime, making this a heavily
            down-scaled absolute-error loss -- confirm this is intended.

    Returns:
        Backend tensor of per-element losses; no reduction is applied here.

    Raises:
        ValueError: if ``clip_value`` is not strictly positive (or is NaN).
        RuntimeError: if the Keras backend is neither TensorFlow nor Theano.
    """
    # Explicit check rather than ``assert`` (which ``python -O`` strips).
    # Written as ``not > 0.`` so NaN is rejected too.
    if not clip_value > 0.:
        raise ValueError('clip_value must be > 0, got {!r}'.format(clip_value))
    x = y_true - y_pred
    if np.isinf(clip_value):
        # Special case for infinity since TensorFlow has problems
        # if we compare `K.abs(x) < np.inf`.
        return .5 * K.square(x)
    abs_x = K.abs(x)
    condition = abs_x < clip_value
    squared_loss = .5 * K.square(x)
    linear_loss = clip_value * (abs_x - .5 * clip_value)
    backend = K.backend()
    if backend == 'tensorflow':
        import tensorflow as tf
        # Older TensorFlow releases expose tf.select; newer ones only tf.where.
        if hasattr(tf, 'select'):
            return tf.select(condition, squared_loss, linear_loss)  # condition, true, false
        return tf.where(condition, squared_loss, linear_loss)  # condition, true, false
    if backend == 'theano':
        from theano import tensor as T
        return T.switch(condition, squared_loss, linear_loss)
    raise RuntimeError('Unknown backend "{}".'.format(backend))
# Make huber_loss resolvable by name through Keras' global custom-object
# registry, so load_model() can find it without an explicit custom_objects
# argument.
from keras.utils.generic_utils import get_custom_objects
get_custom_objects()["huber_loss"] = huber_loss
if __name__ == "__main__":
parser = OptionParser()
parser.add_option('-m','--model' ,action='store',type='string',dest='inputModel' ,default='train_simple/KERAS_check_best_model.h5', help='input model')
(options,args) = parser.parse_args()
model = load_model(options.inputModel)
print_model_to_json(model,options.inputModel.replace('.h5','.json'))#, custom_objects={'huber_loss': huber_loss})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment