import json
import grpc 

import numpy as np
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

# Placeholder connection settings used by the ``__main__`` demo;
# replace them with your real certificate path and endpoint.
crt_path = "/path/to/your/cert/server.crt"  # PEM root certificate(s) trusted for TLS
host = "your.elb.us-east-1.amazonaws.com"
# An int, matching get_grpc_connection's ``port=443`` default; it is
# formatted into a "host:port" target string, so the value is unchanged.
port = 443

def get_grpc_connection(host, port=443, crt_path="server.crt"):
    """Open a TLS-secured gRPC channel to ``host:port``.

    Args:
        host: Server hostname or address to connect to.
        port: Server port. It is formatted into the ``"host:port"`` target
            string, so an int or a numeric string both work.
        crt_path: Path to a file holding the PEM-encoded root certificate(s)
            used to verify the server. Pass ``None`` to skip reading a file
            and let gRPC use its default root certificates.

    Returns:
        A ``grpc.Channel``. The caller owns it and should ``close()`` it (or
        use it as a context manager) when finished.

    Raises:
        OSError: If ``crt_path`` is given but the file cannot be read.
    """
    trusted_certs = None
    if crt_path is not None:
        with open(crt_path, "rb") as f:
            trusted_certs = f.read()

    # root_certificates=None makes gRPC fall back to its default trust store.
    credentials = grpc.ssl_channel_credentials(root_certificates=trusted_certs)
    return grpc.secure_channel(f"{host}:{port}", credentials)

def grpc_predict(channel, data, model_name, inputs_key, route, signature="serving_default",
                 timeout_sec=5, output_key=None):
    """Send a TensorFlow Serving ``Predict`` request over ``channel``.

    The RPC is issued on the method path
    ``/<route>/tensorflow.serving.PredictionService/Predict``. The prefix
    presumably lets a path-routing proxy in front of TF Serving pick the
    backend service (NOTE(review): confirm against the deployment).

    Args:
        channel: An open ``grpc.Channel`` (e.g. from ``get_grpc_connection``).
        data: Array-like input. It is converted with ``np.array`` and sent as
            a tensor built with ``dtype=float``.
        model_name: Name of the served model (``model_spec.name``).
        inputs_key: Name of the input tensor in the signature.
        route: Path prefix put in front of the gRPC method name, e.g.
            ``"/service1"``. Leading and trailing slashes are normalized, so
            ``"service1"`` and ``"/service1/"`` behave like ``"/service1"``.
            ``""`` means no prefix.
        signature: Signature name (``model_spec.signature_name``).
        timeout_sec: RPC deadline in seconds.
        output_key: If given, the named output tensor is decoded and returned
            as a NumPy array. If ``None`` (the default), the legacy return
            value described below is kept.

    Returns:
        If ``output_key`` is set: a NumPy array from
        ``tf.make_ndarray(response.outputs[output_key])``.
        Otherwise: ``np.array(response)``, an object array wrapping the raw
        ``PredictResponse`` message. It does NOT contain the prediction values
        as numbers; it is kept only for backward compatibility.

    Raises:
        grpc.RpcError: If the RPC fails or exceeds ``timeout_sec``.
        KeyError: If ``output_key`` is not among the response's outputs.
    """
    # Build the method path. Strip slashes so a prefix with a missing or
    # extra slash still yields a valid "/prefix/Service/Method" path.
    prefix = route.strip("/")
    method = "/tensorflow.serving.PredictionService/Predict"
    if prefix:
        method = f"/{prefix}{method}"

    # Create the unary-unary callable directly; a generated stub would only
    # hard-code the un-prefixed method path.
    predict = channel.unary_unary(
        method,
        request_serializer=predict_pb2.PredictRequest.SerializeToString,
        response_deserializer=predict_pb2.PredictResponse.FromString,
    )

    request = predict_pb2.PredictRequest()
    request.model_spec.name = model_name
    request.model_spec.signature_name = signature
    # NOTE(review): dtype=float is resolved by TensorFlow's dtype mapping;
    # confirm it matches the dtype the model's signature expects.
    request.inputs[inputs_key].CopyFrom(
        tf.make_tensor_proto(np.array(data), dtype=float)
    )

    # A blocking call with a deadline is equivalent to .future(...).result().
    response = predict(request, timeout=timeout_sec)

    if output_key is None:
        return np.array(response)

    # Test membership first: indexing a protobuf map with a missing key
    # silently inserts an empty entry instead of raising.
    if output_key not in response.outputs:
        raise KeyError(
            f"output {output_key!r} not in response; "
            f"available outputs: {sorted(response.outputs)}"
        )
    return tf.make_ndarray(response.outputs[output_key])

if __name__ == "__main__":
    channel = get_grpc_connection(host, port, crt_path)

    # The channel as a context manager is closed on exit, including when the
    # readiness check fails or the prediction raises.
    with channel:
        try:
            # Block up to 10 s until the channel reports it is ready.
            grpc.channel_ready_future(channel).result(timeout=10)
        except grpc.FutureTimeoutError:
            # Catch only the timeout: a bare ``except:`` would also swallow
            # KeyboardInterrupt. Calling Predict on a channel that never became
            # ready would only fail later with a less clear error, so stop here.
            print(":(")
            raise SystemExit(1)
        print(":)")

        data = [1.0, 2.0, 5.0]
        output = grpc_predict(channel, data, 'model', 'x', '/service1')
        print(output)