Last active
December 20, 2017 19:00
-
-
Save zmjjmz/ce9c7a896933a02953cae0069a2ca21e to your computer and use it in GitHub Desktop.
TF weight discrepancy repro code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
import os | |
import sys | |
import subprocess | |
import time | |
import numpy as np | |
import tensorflow | |
from tensorflow.python.tools import inspect_checkpoint | |
from tensorflow.python.ops import ( | |
control_flow_ops, | |
variables, | |
lookup_ops | |
) | |
from grpc.beta import implementations | |
from tensorflow_serving.apis import prediction_service_pb2 | |
from tensorflow_serving.apis import predict_pb2 | |
from tensorflow.python.framework import tensor_util | |
import keras | |
def export_keras_model_simple(model, export_dir, version, conflict_policy='fail'):
    """Export a Keras model as a TensorFlow SavedModel in ``export_dir/<version>``.

    The model is saved from the current Keras backend session under the
    SERVING tag with a single ``'predict'`` signature whose inputs/outputs are
    keyed by the names of the model's input/output layers.

    Only supports one input / one output models (per the original author).

    Args:
        model: Keras model exposing ``input_layers``, ``inputs``,
            ``output_layers`` and ``outputs``.
        export_dir: Base directory holding one sub-directory per version.
        version: Version identifier (supposed to be an unsigned integer); its
            ``str()`` is used as the sub-directory name.
        conflict_policy: What to do when ``export_dir/<version>`` already
            exists:
            ``'fail'`` raises, ``'replace'`` deletes the existing directory
            (no backup is made), ``'bump'`` exports to one past the highest
            integer-named entry found in ``export_dir``.

    Returns:
        The directory the model was actually written to (differs from
        ``export_dir/<version>`` when the version was bumped).

    Raises:
        ValueError: If the target directory exists and the policy is
            ``'fail'``, or the policy is not one of the three known values.
    """
    version_str = str(version)
    final_export_dir = os.path.join(export_dir, version_str)

    # Resolve directory conflicts before touching any global Keras state so a
    # rejected export has no side effects.
    # TODO (ZJ) move this outside of here
    if os.path.exists(final_export_dir):
        print("Export dir exists! Resolving according to conflict policy: {0}".format(
            conflict_policy))
        if conflict_policy == 'bump':
            # Only entries whose names parse as integers count as versions.
            viable_versions = []
            for entry in os.listdir(export_dir):
                try:
                    viable_versions.append(int(entry))
                except ValueError:
                    continue
            # With no numeric version present (e.g. a non-numeric `version`
            # was passed) start numbering at 0 instead of failing in max().
            new_version = max(viable_versions) + 1 if viable_versions else 0
            print("Bumping to version {0}".format(new_version))
            final_export_dir = os.path.join(export_dir, str(new_version))
        elif conflict_policy == 'replace':
            # just remove the current export dir and replace it. not a great idea!
            # doesn't create a backup, so if for some reason the rest of this fails,
            # it will just remove that directory...
            import shutil
            shutil.rmtree(final_export_dir)
        elif conflict_policy == 'fail':
            raise ValueError(
                "Directory {0} already exists, please specify a new version!".format(
                    final_export_dir))
        else:
            # Previously an unrecognised policy fell through silently and the
            # export was attempted on top of the existing directory.
            raise ValueError(
                "Unknown conflict policy {0!r}; expected 'bump', 'replace' or 'fail'".format(
                    conflict_policy))

    # for anything that has to change behavior here
    keras.backend.set_learning_phase(0)

    builder = tensorflow.saved_model.builder.SavedModelBuilder(
        final_export_dir)

    # Signature tensors are keyed by the name of the layer they belong to.
    input_map = {il.name: il_t for il, il_t in zip(model.input_layers, model.inputs)}
    output_map = {ol.name: ol_t for ol, ol_t in zip(model.output_layers, model.outputs)}
    print(input_map)
    print(output_map)
    signature = tensorflow.saved_model.signature_def_utils.predict_signature_def(
        inputs=input_map,
        outputs=output_map)

    # Alternative main op without the global-variable initializer. Currently
    # unused (see the commented-out main_op argument below); kept so the two
    # variants can be swapped while investigating.
    main_op_new = control_flow_ops.group(
        lookup_ops.tables_initializer(),
        variables.local_variables_initializer(),
        #variables.global_variables_initializer()
    )

    sess = keras.backend.get_session()
    # NOTE(review): main_op.main_op() presumably also groups the global
    # variables initializer, which may be what this repro is probing -- confirm
    # against the TF docs before reusing this export path elsewhere.
    builder.add_meta_graph_and_variables(
        sess=sess,
        tags=[tensorflow.saved_model.tag_constants.SERVING],
        signature_def_map={'predict': signature},
        main_op=tensorflow.saved_model.main_op.main_op()
        #main_op=main_op_new
    )
    builder.save()
    return final_export_dir
def simple_serving_client(inp_, host, port, input_name='lookedup', signature='predict', model_name='default'):
    """Send ``inp_`` to a TensorFlow Serving Predict endpoint and return the reply.

    Args:
        inp_: Array-like with a ``shape`` attribute; it is serialized as an
            int32 tensor proto.
        host: Host name of the model server.
        port: Port of the model server (converted with ``int``).
        input_name: Key under which the tensor is placed in the request inputs.
        signature: Name of the serving signature to invoke.
        model_name: Name of the served model.

    Returns:
        Whatever the stub's ``Predict`` call returns for the request.
    """
    grpc_channel = implementations.insecure_channel(host, int(port))
    predict_stub = prediction_service_pb2.beta_create_PredictionService_stub(grpc_channel)

    predict_request = predict_pb2.PredictRequest()
    spec = predict_request.model_spec
    spec.name = model_name
    spec.signature_name = signature

    payload = tensorflow.make_tensor_proto(inp_, shape=inp_.shape, dtype='int32')
    predict_request.inputs[input_name].CopyFrom(payload)

    # The second positional argument is presumably the RPC timeout in seconds.
    return predict_stub.Predict(predict_request, 10)
def keras_tf_minimal_repro(store_path):
    """Build a tiny frozen-embedding model, export it, and compare outputs.

    Prints, in order: the Keras prediction, the tensors stored in the exported
    checkpoint, the prediction returned by the model server on localhost:5566,
    and the Keras prediction once more.

    Args:
        store_path: Base directory the model is exported into; the model
            server is expected to be watching it.
    """
    pad_length = 5
    embedding_mat = np.zeros(shape=(5, 3))

    # Two rows of indices: one all ones, one all zeros.
    ones_row = np.ones(pad_length, dtype=np.int32)
    zeros_row = np.zeros(pad_length, dtype=np.int32)
    data = np.stack([ones_row, zeros_row])

    vocab_size, embed_dim = embedding_mat.shape
    inp = keras.layers.Input(shape=(pad_length,), name='lookedup', dtype='int32')
    embed_layer = keras.layers.Embedding(
        vocab_size,
        embed_dim,
        weights=[embedding_mat],
        input_length=pad_length,
        name='embed',
        trainable=False)
    emb = embed_layer(inp)
    model = keras.models.Model(inputs=[inp], outputs=[emb])

    keras_out = model.predict(data)
    print("output from keras")
    print(keras_out)

    actual_dir = export_keras_model_simple(model, store_path, 0, conflict_policy='bump')
    time.sleep(5)  # give TF serving some time to load the new version

    # Should be all zeros
    print("model weights")
    checkpoint_prefix = os.path.join(actual_dir, 'variables/variables')
    inspect_checkpoint.print_tensors_in_checkpoint_file(checkpoint_prefix, "", True)

    res_proto = simple_serving_client(data, 'localhost', 5566)
    served_out = tensor_util.MakeNdarray(res_proto.outputs['embed'])
    print("tf serving output")
    print(served_out)

    keras_again = model.predict([data])
    print("keras output")
    print(keras_again)
if __name__ == "__main__":
    # The export base path is the first (and only) command-line argument.
    if len(sys.argv) < 2:
        print("Please provide a path to store the exported model")
        sys.exit(1)
    store_path = sys.argv[1]
    # BEFORE RUNNING THIS
    # run tensorflow_model_server --model_base_path=<store_path> --port=5566
    keras_tf_minimal_repro(store_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment