Last active
February 14, 2020 06:32
-
-
Save hanneshapke/09db574f0d02623552f216c08e4336b4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import base64 | |
import googleapiclient.discovery | |
from example_pb2 import Example | |
from feature_pb2 import BytesList, Feature, Features | |
def _convert_to_pb(value):
    """Serialize a sentence into a tf.Example-compatible protobuf string.

    The payload carries a single `sentence` feature holding the UTF-8
    encoded text. Extend the feature dict if your model requires more
    inputs; an overview of the available feature data types is at
    https://www.tensorflow.org/alpha/tutorials/load_data/tf_records
    """
    sentence_feature = Feature(bytes_list=BytesList(value=[value.encode()]))
    tf_example = Example(features=Features(feature={'sentence': sentence_feature}))
    return tf_example.SerializeToString()
def _generate_payload(sentence):
    """Build the request body for a model inference call.

    Model servers speaking the protobuf format (GCP ML Engine or a
    self-hosted TensorFlow Serving instance) expect a dictionary with an
    `instances` key whose value is a list of example dicts; each
    serialized protobuf is base64-encoded under the `b64` key.
    """
    serialized_example = _convert_to_pb(sentence)
    b64_example = base64.b64encode(serialized_example).decode()
    return {'instances': [{'examples': {"b64": b64_example}}]}
def _get_model_prediction(service, project, model='demo_model', body=None): | |
""" Infers the model server for a prediction | |
""" | |
if body is None: | |
raise NotImplementedError(f"_get_model_prediction didn't get any payload for model {model}") | |
response = service.projects().predict( | |
name='projects/{}/models/{}'.format(project, model), | |
body=body | |
).execute() | |
return response | |
def _connect_service():
    """Return a Google API client object for the ML Engine (v1) service.

    If you authenticate with Google service-account credentials, load
    them here and pass them as additional keyword arguments to `build`.
    """
    return googleapiclient.discovery.build(serviceName='ml', version='v1')
# An example prediction: connect to the service, build the request
# payload from a raw sentence, and call the deployed model.
service = _connect_service()
sentence = "Classify me"  # raw text to send to the model
body = _generate_payload(sentence)  # serialized tf.Example wrapped in the REST payload
prediction = _get_model_prediction(service, project='my_project', model='demo_model', body=body)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment