@zmjjmz
Last active December 20, 2017 19:00
TF weight discrepancy repro code
from __future__ import division
import os
import sys
import subprocess
import time
import numpy as np
import tensorflow
from tensorflow.python.tools import inspect_checkpoint
from tensorflow.python.ops import (
    control_flow_ops,
    variables,
    lookup_ops
)
from grpc.beta import implementations
from tensorflow_serving.apis import prediction_service_pb2
from tensorflow_serving.apis import predict_pb2
from tensorflow.python.framework import tensor_util
import keras
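

# Repro outline: build a Keras model whose only layer is a non-trainable
# Embedding initialized to an all-zeros matrix, export it as a SavedModel,
# then compare the in-process Keras prediction against the response returned
# by a running tensorflow_model_server for the same input.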
def export_keras_model_simple(model, export_dir, version, conflict_policy='fail'):
    # only supports one input / one output models
    # version is supposed to be an uint!
    version_str = str(version)
    # for anything that has to change behavior here
    keras.backend.set_learning_phase(0)
    final_export_dir = os.path.join(export_dir, version_str)
    # TODO (ZJ) move this outside of here
    if os.path.exists(final_export_dir):
        print("Export dir exists! Resolving according to conflict policy: {0}".format(
            conflict_policy))
        if conflict_policy == 'bump':
            import glob
            existing_versions = map(lambda x: x.split('/')[-1],
                                    glob.glob(os.path.join(export_dir, '*')))
            viable_versions = []
            for v in existing_versions:
                try:
                    viable_versions.append(int(v))
                except ValueError:
                    continue
            new_version = max(viable_versions) + 1
            print("Bumping to version {0}".format(new_version))
            final_export_dir = os.path.join(export_dir, str(new_version))
        if conflict_policy == 'replace':
            # just remove the current export dir and replace it. not a great idea!
            # doesn't create a backup, so if for some reason the rest of this fails,
            # it will just remove that directory...
            import shutil
            shutil.rmtree(final_export_dir)
        if conflict_policy == 'fail':
            # raise an error
            raise ValueError(
                "Directory {0} already exists, please specify a new version!".format(final_export_dir))

    builder = tensorflow.saved_model.builder.SavedModelBuilder(final_export_dir)
    input_map = {il.name: il_t for il, il_t in zip(model.input_layers, model.inputs)}
    output_map = {ol.name: ol_t for ol, ol_t in zip(model.output_layers, model.outputs)}
    print(input_map)
    print(output_map)
    signature = tensorflow.saved_model.signature_def_utils.predict_signature_def(
        inputs=input_map,
        outputs=output_map)
    main_op_new = control_flow_ops.group(
        lookup_ops.tables_initializer(),
        variables.local_variables_initializer(),
        # variables.global_variables_initializer()
    )
    sess = keras.backend.get_session()
    builder.add_meta_graph_and_variables(
        sess=sess,
        tags=[tensorflow.saved_model.tag_constants.SERVING],
        signature_def_map={'predict': signature},
        main_op=tensorflow.saved_model.main_op.main_op()
        # main_op=main_op_new
    )
    builder.save()
    return final_export_dir
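

# Minimal TF Serving client: packs `inp_` into an int32 TensorProto, issues a
# Predict RPC against `model_name`/`signature`, and returns the raw response proto.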
def simple_serving_client(inp_, host, port, input_name='lookedup', signature='predict', model_name='default'):
    channel = implementations.insecure_channel(host, int(port))
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = model_name
    request.model_spec.signature_name = signature
    request.inputs[input_name].CopyFrom(
        tensorflow.make_tensor_proto(
            inp_,
            shape=inp_.shape,
            dtype='int32'))
    result = stub.Predict(request, 10)
    return result
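

# End-to-end repro: the Embedding weights are all zeros, so every output below
# should be all zeros; the checkpoint dump and the TF Serving response can then
# be compared against the in-process Keras prediction to spot the discrepancy.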
def keras_tf_minimal_repro(store_path):
    embedding_mat = np.zeros(shape=(5, 3))
    pad_length = 5
    data = np.stack([
        np.ones(pad_length, dtype=np.int32),
        np.zeros(pad_length, dtype=np.int32)])
    inp = keras.layers.Input(shape=(pad_length,), name='lookedup', dtype='int32')
    emb = keras.layers.Embedding(*(embedding_mat.shape), weights=[embedding_mat],
                                 input_length=pad_length, name='embed',
                                 trainable=False)(inp)
    model = keras.models.Model(inputs=[inp], outputs=[emb])
    keras_out = model.predict(data)
    print("output from keras")
    print(keras_out)
    actual_dir = export_keras_model_simple(model, store_path, 0, conflict_policy='bump')
    time.sleep(5)  # give TF serving some time to load the new version
    # Should be all zeros
    print("model weights")
    inspect_checkpoint.print_tensors_in_checkpoint_file(
        os.path.join(actual_dir, 'variables/variables'), "", True)
    res_proto = simple_serving_client(data, 'localhost', 5566)
    proto_val = tensor_util.MakeNdarray(res_proto.outputs['embed'])
    print("tf serving output")
    print(proto_val)
    model_output = model.predict([data])
    print("keras output")
    print(model_output)


if __name__ == "__main__":
    try:
        store_path = sys.argv[1]
    except IndexError:
        print("Please provide a path to store the exported model")
        sys.exit(1)
    # BEFORE RUNNING THIS
    # run tensorflow_model_server --model_base_path=<store_path> --port=5566
    keras_tf_minimal_repro(store_path)
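    # Example invocation (script filename and export path are placeholders,
    # not part of the original gist), once the model server above is running:
    #   python tf_weight_discrepancy_repro.py /tmp/model_export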