A low-level client that connects to a running gRPC server and sends it a TensorFlow Serving prediction request
""" | |
Adapted from https://github.com/Vetal1977/tf_serving_example/blob/master/svnh_semi_supervised_client.py | |
""" | |
# -*- coding: utf-8 -*- | |
import time | |
from argparse import ArgumentParser | |
import numpy as np | |
# Communication to TensorFlow server via gRPC | |
from grpc.beta import implementations | |
import tensorflow as tf | |
# TensorFlow serving stuff to send messages | |
from tensorflow_serving.apis import predict_pb2 | |
from tensorflow_serving.apis import prediction_service_pb2 | |
from tensorflow.contrib.util import make_tensor_proto | |
from os import listdir | |
from os.path import isfile, join | |
import helper | |
timeout = 60.0 | |
class Server:
    def __init__(self, host, port):
        # Channel and stub are gRPC boilerplate
        channel = implementations.insecure_channel(host, int(port))
        self.stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
        # Load cached preprocessing results (corpora and vocabulary lookup tables)
        (source_int_text, target_int_text), \
            (source_vocab_to_int, target_vocab_to_int), \
            (source_int_to_vocab, target_int_to_vocab) = helper.load_preprocess()
        self.source_int_text = source_int_text
        self.target_int_text = target_int_text
        self.source_vocab_to_int = source_vocab_to_int
        self.target_vocab_to_int = target_vocab_to_int
        self.source_int_to_vocab = source_int_to_vocab
        self.target_int_to_vocab = target_int_to_vocab
    def translate(self, word):
        batch_size = 32  # TODO: load this from the model
        print(word)
        # Space-separate the characters of each input line so they tokenize individually
        prediction_x = "\n".join([" ".join(line) for line in word.split("\n")])
        translate_sentence = helper.sentence_to_seq(prediction_x, self.source_vocab_to_int)
        # NOTE: must stay a plain list; converting to NumPy breaks the int64-to-int32 conversion
        input_data = [translate_sentence] * batch_size
        # Build the prediction request (boilerplate)
        request = predict_pb2.PredictRequest()
        request.model_spec.name = '0'  # TODO: rename this to ru2ipa in saved_models
        request.model_spec.signature_name = 'serving_default'  # default signature constant
        # Populate the request inputs with the TensorFlow Serving `CopyFrom` setter
        request.inputs['input'].CopyFrom(make_tensor_proto(input_data, shape=[batch_size, len(input_data[0])]))
        request.inputs['keep_prob'].CopyFrom(make_tensor_proto(1.0))
        target_sequence_length = [len(translate_sentence) * 2] * batch_size
        request.inputs['target_sequence_length'].CopyFrom(make_tensor_proto(target_sequence_length))
        source_sequence_length = [len(translate_sentence)] * batch_size
        request.inputs['source_sequence_length'].CopyFrom(make_tensor_proto(source_sequence_length))
        # Send the request and decode the first prediction in the batch (boilerplate)
        response = self.stub.Predict(request, timeout)
        result = response.outputs['predictions']
        result_array = tf.make_ndarray(result)[0]
        x = "".join([self.source_int_to_vocab[i] for i in translate_sentence])  # decoded source, kept for debugging
        y = "".join([self.target_int_to_vocab[i] for i in result_array])
        return y
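For context, a minimal sketch of how this client might be driven from another script. The module name, host, port, and example input below are illustrative placeholders, not values taken from the gist; they assume the file above is saved as server.py and that a TensorFlow Serving instance is already running.

from server import Server  # hypothetical module name for the file above

client = Server('localhost', 9000)  # hypothetical host/port of the running TensorFlow Serving instance
print(client.translate('привет'))   # hypothetical input word; prints the decoded prediction string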