Skip to content

Instantly share code, notes, and snippets.

@hanneshapke
Last active February 14, 2020 06:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save hanneshapke/09db574f0d02623552f216c08e4336b4 to your computer and use it in GitHub Desktop.
Save hanneshapke/09db574f0d02623552f216c08e4336b4 to your computer and use it in GitHub Desktop.
import base64
import googleapiclient.discovery
from example_pb2 import Example
from feature_pb2 import BytesList, Feature, Features
def _convert_to_pb(value):
    """Serialize a sentence into the ProtoBuf structure of a ``tf.Example``.

    The example carries a single ``bytes_list`` feature named ``sentence``.
    Feel free to add more features or different data types if your model
    requires different inputs. An overview of all available feature types can
    be found in the TensorFlow "TFRecord and tf.train.Example" tutorial:
    https://www.tensorflow.org/tutorials/load_data/tfrecord

    Args:
        value: The sentence to encode. ``str`` input is UTF-8 encoded;
            ``bytes``/``bytearray`` input is used as-is.

    Returns:
        The serialized ``Example`` message as ``bytes``.

    Raises:
        TypeError: If ``value`` is neither text nor bytes.
    """
    # Accept already-encoded input too; str.encode() would reject bytes with a
    # TypeError. Anything else still goes through str.encode() so the original
    # TypeError for unsupported types is preserved.
    if isinstance(value, (bytes, bytearray)):
        raw = bytes(value)
    else:
        raw = str.encode(value)
    example = Example(features=Features(feature={
        'sentence': Feature(bytes_list=BytesList(value=[raw])),
    }))
    return example.SerializeToString()
def _generate_payload(sentence):
    """Build the prediction request body for a single sentence.

    If the protobuf format is used, the model server instance (GCP ML Engine or
    your own TensorFlow Serving instance) expects a dictionary with a key
    ``instances`` whose value is a list of example dictionaries.

    The sentence is serialized with ``_convert_to_pb``. The resulting bytes are
    base64-encoded and wrapped in a ``{"b64": ...}`` object under the
    ``examples`` key, so the binary payload can travel inside JSON.
    """
    serialized = _convert_to_pb(sentence)
    b64_text = base64.b64encode(serialized).decode()
    instance = {'examples': {"b64": b64_text}}
    return {'instances': [instance]}
def _get_model_prediction(service, project, model='demo_model', body=None, version=None):
    """Request an online prediction from a deployed model.

    Args:
        service: API client for the ``ml`` v1 API, e.g. the object returned by
            ``_connect_service``.
        project: GCP project ID that owns the model.
        model: Name of the deployed model.
        body: Request payload, e.g. the dict built by ``_generate_payload``.
        version: Optional model version name. When omitted, the request
            targets ``projects/{project}/models/{model}`` exactly as before.
            Per the ``projects.predict`` reference, the server then uses the
            model's default version.

    Returns:
        Whatever ``service.projects().predict(...).execute()`` returns.

    Raises:
        NotImplementedError: If ``body`` is ``None``.
    """
    if body is None:
        # A missing payload is really a bad argument, but NotImplementedError is
        # kept so existing callers that catch it keep working.
        raise NotImplementedError(f"_get_model_prediction didn't get any payload for model {model}")
    name = f'projects/{project}/models/{model}'
    if version is not None:
        name = f'{name}/versions/{version}'
    return service.projects().predict(name=name, body=body).execute()
def _connect_service(credentials=None):
    """Create a GCP service API object for the ``ml`` v1 API.

    Args:
        credentials: Optional credentials object, e.g. loaded from a Google
            service account key file. When ``None``, no ``credentials``
            argument is passed to ``build``, so the client library's default
            credential handling applies (identical to the previous behavior).

    Returns:
        The service object returned by ``googleapiclient.discovery.build``.
    """
    kwargs = {'serviceName': 'ml', 'version': 'v1'}
    # Only forward credentials when explicitly given, so the call is unchanged
    # for existing callers.
    if credentials is not None:
        kwargs['credentials'] = credentials
    return googleapiclient.discovery.build(**kwargs)
if __name__ == '__main__':
    # An example prediction. Guarded so that importing this module does not
    # open a service connection or send a prediction request.
    service = _connect_service()
    sentence = "Classify me"
    body = _generate_payload(sentence)
    prediction = _get_model_prediction(service, project='my_project', model='demo_model', body=body)
    print(prediction)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment