Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save malcolmgreaves/6f15a4f0b22eb2217913d6d711552662 to your computer and use it in GitHub Desktop.
Save malcolmgreaves/6f15a4f0b22eb2217913d6d711552662 to your computer and use it in GitHub Desktop.
Example showing a multi-input-node Keras model that's trained on toy data and saved as a TensorFlow Serving model artifact.
import os
import shutil
import sys

import numpy as np
import tensorflow as tf
from tensorflow.python.keras import Input, Model
from tensorflow.python.keras.layers import Dense, concatenate
def single_feat_input_model(n_output_per: int) -> Model:
    """Build a small model that maps one scalar feature to a ReLU dense layer.

    Args:
        n_output_per: Number of units in the dense layer applied to the input.

    Returns:
        A Keras ``Model`` whose input has shape ``(1,)`` and whose output is
        the ``n_output_per``-unit ReLU dense layer applied to that input.
    """
    feature_in = Input(shape=(1,))
    dense_out = Dense(n_output_per, activation="relu")(feature_in)
    return Model(inputs=feature_in, outputs=dense_out)
def make_multi_single_feat(
    n_input: int, n_output_per: int, n_output_final: int = 1
) -> Model:
    """Build a model with ``n_input`` separate single-feature input nodes.

    Each input goes through its own ``single_feat_input_model`` branch; the
    branch outputs are concatenated and fed to one final dense layer.

    Args:
        n_input: Number of independent input nodes (must be >= 1).
        n_output_per: Number of dense units in each per-input branch.
        n_output_final: Number of units in the final output layer.

    Returns:
        A Keras ``Model`` taking a list of ``n_input`` inputs and producing a
        single output tensor with ``n_output_final`` units.

    Raises:
        ValueError: If ``n_input`` is less than 1.
    """
    if n_input < 1:
        raise ValueError("n_input must be at least 1, got {}".format(n_input))

    branches = [single_feat_input_model(n_output_per) for _ in range(n_input)]
    branch_outputs = [branch.output for branch in branches]

    # A single branch has nothing to merge with; some Keras versions reject
    # concatenating fewer than two tensors, so use the lone output directly.
    if len(branch_outputs) > 1:
        combined = concatenate(branch_outputs)
    else:
        combined = branch_outputs[0]

    # Softmax normalizes across the output units, so with a single unit it is
    # constantly 1.0 and nothing can be learned. Use sigmoid for the one-unit
    # (binary) case and softmax only when there are several output units.
    final_activation = "sigmoid" if n_output_final == 1 else "softmax"
    final_layer = Dense(n_output_final, activation=final_activation)(combined)

    return Model(
        inputs=[branch.input for branch in branches], outputs=final_layer
    )
def remove(path: str) -> None:
    """Delete whatever exists at ``path``; do nothing if the path is absent.

    Regular files and symbolic links are unlinked. Directories are removed
    together with all of their contents, so a previously exported model
    directory (which is not empty) can be cleared.

    Args:
        path: Filesystem path to delete.
    """
    # Check for links first so a link is unlinked rather than followed.
    if os.path.islink(path) or os.path.isfile(path):
        os.remove(path)
    elif os.path.isdir(path):
        # os.rmdir only works on empty directories; rmtree handles contents.
        shutil.rmtree(path)
if __name__ == "__main__":
    # Export destination: first command-line argument if given, else a default.
    export_path = "./model_export"
    if len(sys.argv) > 1:
        export_path = sys.argv[1]

    # Delete anything already present at the export location.
    remove(export_path)

    N_INPUT = 5

    # Toy data laid out as 3 examples x 5 features x 1 value per feature.
    X = np.array(
        [
            [[1], [2], [3], [4], [5]],
            [[3], [1], [5], [17], [18]],
            [[0], [0], [0], [0], [0]],
        ],
        dtype=float,
    )
    # The model has one input node per feature, so slice out one
    # (n_examples, 1) array for each feature position.
    X_multi_input_examples = [
        X[:, feature_idx, :] for feature_idx in range(N_INPUT)
    ]
    Y = np.array([1, 1, 0], dtype=int)

    model = make_multi_single_feat(
        n_input=N_INPUT, n_output_per=3, n_output_final=1
    )
    model.compile(optimizer="adadelta", loss="binary_crossentropy")
    history = model.fit(X_multi_input_examples, Y)

    # The export signature is keyed by each tensor's own name.
    signature_inputs = {}
    for in_tensor in model.input:
        signature_inputs[in_tensor.name] = in_tensor
    signature_outputs = {}
    for out_tensor in model.outputs:
        signature_outputs[out_tensor.name] = out_tensor

    # Save the trained graph from the Keras backend session to export_path.
    with tf.keras.backend.get_session() as sess:
        tf.saved_model.simple_save(
            sess,
            export_path,
            inputs=signature_inputs,
            outputs=signature_outputs,
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment