@rpicatoste
Last active October 28, 2022 09:40
Function to convert a Keras LSTM model trained as stateless into a stateful model that expects a single sample and a single time step per call, for use during inference.
import json
from keras.models import model_from_json


def convert_to_inference_model(original_model):
    # Serialize the trained stateless model to JSON so its config can be edited.
    original_model_json = original_model.to_json()
    inference_model_dict = json.loads(original_model_json)

    layers = inference_model_dict['config']
    for layer in layers:
        # Make recurrent layers stateful so hidden state persists between calls.
        if 'stateful' in layer['config']:
            layer['config']['stateful'] = True

        # Expect a batch of one sample with a variable number of time steps.
        if 'batch_input_shape' in layer['config']:
            layer['config']['batch_input_shape'][0] = 1
            layer['config']['batch_input_shape'][1] = None

    # Rebuild the model from the edited config and copy over the trained weights.
    inference_model = model_from_json(json.dumps(inference_model_dict))
    inference_model.set_weights(original_model.get_weights())

    return inference_model
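
A minimal usage sketch (illustrative only; the model definition, the 450-step / 7-feature shape taken from the comments below, and the dummy data are assumptions, not part of the gist):

import numpy as np
from keras.models import Sequential
from keras.layers import LSTM, Dense

# Train a stateless model on full sequences (e.g. 450 time steps, 7 features).
model = Sequential()
model.add(LSTM(32, input_shape=(450, 7)))
model.add(Dense(1))
model.compile(loss='mse', optimizer='adam')
# model.fit(x_train, y_train, ...)

# Convert it to a stateful model that takes one sample and one time step per call.
inference_model = convert_to_inference_model(model)

sequence = np.random.rand(450, 7).astype('float32')  # placeholder data
inference_model.reset_states()
for t in range(sequence.shape[0]):
    step = sequence[t].reshape(1, 1, 7)  # batch of 1, single time step
    prediction = inference_model.predict(step)
# Call reset_states() again before starting an independent sequence.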
@emuccino

emuccino commented Feb 19, 2019

Should the line
layers = inference_model_dict['config']
instead be:
layers = inference_model_dict['config']['layers']

@rpicatoste

rpicatoste commented Feb 26, 2019

It looks like we are seeing different structures, maybe?
In inference_model_dict['config'] I get the list of layers directly (the config of each layer, actually):

type(inference_model_dict['config'])
Out[4]: list
inference_model_dict['config'][0]
Out[5]: 
{'class_name': 'LSTM',
 'config': {'name': 'lstm_8',
  'trainable': True,
  'batch_input_shape': [None, 450, 7],
  'dtype': 'float32',
  ...
  'implementation': 1}}

I am using Keras 2.1.5; could it be that this structure changed in a later version? Such a change would be a bit surprising, though.
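
A quick, illustrative way to check which layout your Keras version produces (not part of the original function):

import json
config = json.loads(model.to_json())['config']
# Older Keras: 'config' is itself the list of layer dicts.
# Newer Keras / tf.keras: 'config' is a dict and the layers sit under 'layers'.
layers = config if isinstance(config, list) else config['layers']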

@VertexC

VertexC commented Jun 8, 2019

I used the same approach to make the stateless model stateful, but it does not change the model's behaviour when I test it.

@jbonyun

jbonyun commented Feb 9, 2021

Works in keras 2.3.0-tf if you change:

layers = inference_model_dict['config']

to

layers = inference_model_dict['config']['layers']

I also had to adjust a Reshape layer, which was reshaping to the fixed number of timesteps I used during training.
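
For newer Keras versions, here is a hedged sketch of the function adapted along these lines; handling both JSON layouts in one function and the Reshape adjustment are my own guesses at what is described above and may need tweaking for a specific model:

import json
from keras.models import model_from_json

def convert_to_inference_model(original_model):
    inference_model_dict = json.loads(original_model.to_json())

    config = inference_model_dict['config']
    # Older Keras stores the layer list directly; newer versions nest it under 'layers'.
    layers = config if isinstance(config, list) else config['layers']

    for layer in layers:
        if 'stateful' in layer['config']:
            layer['config']['stateful'] = True
        if 'batch_input_shape' in layer['config']:
            layer['config']['batch_input_shape'][0] = 1
            layer['config']['batch_input_shape'][1] = None
        # Assumed fix for a Reshape layer whose target shape encodes the
        # training-time number of time steps; adjust the index if the time
        # dimension is not the first entry of target_shape.
        if layer['class_name'] == 'Reshape':
            layer['config']['target_shape'][0] = 1

    inference_model = model_from_json(json.dumps(inference_model_dict))
    inference_model.set_weights(original_model.get_weights())
    return inference_model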
