
@kemingy
Last active January 12, 2023 04:44
TensorFlow Serving, TensorRT Inference Server (Triton), Multi Model Server (MXNet)

Environments

  • CPU: Intel(R) Xeon(R) Gold 6130 CPU @ 2.10GHz
  • GPU: NVIDIA V100
  • Memory: 251GiB
  • OS: Ubuntu 16.04.6 LTS (Xenial Xerus)

Docker Images:

  • tensorflow/tensorflow:latest-gpu
  • tensorflow/serving:latest-gpu
  • nvcr.io/nvidia/tensorrtserver:19.10-py3

| Framework | Model | Model type | Images | Batch size | Time (s) |
|---|---|---|---|---|---|
| TensorFlow | ResNet50 | TF SavedModel | 32000 | 32 | 83.189 |
| TensorFlow | ResNet50 | TF SavedModel | 32000 | 10 | 86.897 |
| TensorFlow Serving | ResNet50 | TF SavedModel | 32000 | 32 | 120.496 |
| TensorFlow Serving | ResNet50 | TF SavedModel | 32000 | 10 | 116.887 |
| Triton (TensorRT Inference Server) | ResNet50 | TF SavedModel | 32000 | 32 | 201.855 |
| Triton (TensorRT Inference Server) | ResNet50 | TF SavedModel | 32000 | 10 | 171.056 |
| Falcon + msgpack + TensorFlow | ResNet50 | TF SavedModel | 32000 | 32 | 115.686 |
| Falcon + msgpack + TensorFlow | ResNet50 | TF SavedModel | 32000 | 10 | 115.572 |
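
The table only reports total time. Converting it to throughput, and to the extra time each serving layer adds per request compared with calling TensorFlow directly, makes the rows easier to compare. A small sketch with the numbers hard-coded from the table; it assumes one request per batch (which is what the client scripts in this gist do), and the difference lumps together serialization, network and server-side handling.

```python
TOTAL_IMAGES = 32000

# Total seconds from the table above, keyed by batch size
NATIVE_TF = {32: 83.189, 10: 86.897}
SERVED = {
    "TensorFlow Serving": {32: 120.496, 10: 116.887},
    "Triton (TensorRT Inference Server)": {32: 201.855, 10: 171.056},
    "Falcon + msgpack + TensorFlow": {32: 115.686, 10: 115.572},
}


def throughput(seconds: float, total_images: int = TOTAL_IMAGES) -> float:
    """Images per second for a run that processed `total_images` images."""
    return total_images / seconds


def per_request_overhead_ms(native_s: float, served_s: float, batch: int,
                            total_images: int = TOTAL_IMAGES) -> float:
    """Extra milliseconds per request, assuming one request carries one batch."""
    num_requests = total_images // batch
    return (served_s - native_s) / num_requests * 1000.0


if __name__ == "__main__":
    for name, runs in SERVED.items():
        for batch, served in sorted(runs.items()):
            native = NATIVE_TF[batch]
            print(f"{name}, batch {batch}: {throughput(served):.0f} img/s "
                  f"(native {throughput(native):.0f}), "
                  f"+{per_request_overhead_ms(native, served, batch):.1f} ms/request")
```

For example, TensorFlow Serving adds roughly 37 ms per batch-32 request and roughly 9 ms per batch-10 request over the native baseline.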

```python
# Falcon + msgpack + TensorFlow inference server (ResNet50 SavedModel in the working directory)
from wsgiref import simple_server

import falcon
import numpy as np
import tensorflow as tf
from falcon import media

# Set up the msgpack handler in Falcon for both requests and responses
handlers = media.Handlers({
    'application/msgpack': media.MessagePackHandler(),
})
api = falcon.API(media_type='application/msgpack')  # falcon.App in Falcon 3.0+
api.req_options.media_handlers = handlers
api.resp_options.media_handlers = handlers

# Expose only the first GPU to TensorFlow
gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
# tf.config.experimental.set_memory_growth(gpus[0], True)


class Inference:
    def __init__(self):
        self.model = tf.saved_model.load('./')

    def on_get(self, req, resp):
        # Simple liveness check
        resp.media = {'msg': 'working'}

    def on_post(self, req, resp):
        # The client sends the batch as raw float32 bytes plus the batch size
        batch = req.media['batch']
        data = np.frombuffer(req.media['data'], dtype=np.float32).reshape((batch, 224, 224, 3))
        result = self.model(inputs=data).numpy()
        resp.media = {
            'result': result.tobytes(),
        }


api.add_route('/api', Inference())
# wsgiref's server is single-threaded: one request is handled at a time
httpd = simple_server.make_server('0.0.0.0', 8000, api)
httpd.serve_forever()
```

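The server expects the request body to carry the batch as raw float32 bytes plus the batch size, and rebuilds the array with `np.frombuffer(...).reshape(...)`. To see how much data crosses the wire per request, the encoding can be mirrored with the standard-library `array` module (an assumption made only to keep this sketch dependency-free: type code `"f"` is a 4-byte C float on common platforms, matching `np.float32`). The helper names are hypothetical.

```python
import array
from typing import Sequence

HEIGHT, WIDTH, CHANNELS = 224, 224, 3
FLOAT32_BYTES = 4


def payload_bytes(batch: int) -> int:
    """Size of the raw float32 'data' field for a batch of 224x224x3 images."""
    return batch * HEIGHT * WIDTH * CHANNELS * FLOAT32_BYTES


def encode_batch(values: Sequence[float]) -> bytes:
    """Flat floats -> raw native-endian float32 bytes (same layout as ndarray.tobytes())."""
    return array.array("f", values).tobytes()


def decode_batch(payload: bytes, batch: int) -> array.array:
    """Raw bytes -> flat float32 array, rejecting payloads that could not be
    reshaped to (batch, 224, 224, 3)."""
    if len(payload) != payload_bytes(batch):
        raise ValueError(
            f"expected {payload_bytes(batch)} bytes for batch {batch}, got {len(payload)}")
    out = array.array("f")
    out.frombytes(payload)
    return out


if __name__ == "__main__":
    for b in (10, 32):
        print(f"batch {b}: {payload_bytes(b):,} bytes per request")
```

A batch of 10 is 6,021,120 bytes (about 6 MB) and a batch of 32 is 19,267,584 bytes, so moving the input alone is plausibly a noticeable part of the per-request overhead of any out-of-process server.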
```python
# Client for the Falcon + msgpack server: sends `epoch` requests of `batch` random images
from time import time

import msgpack
import numpy as np
import requests
from tqdm import tqdm

if __name__ == "__main__":
    url = 'http://localhost:8000/api'
    session = requests.Session()
    session.headers.update({'Content-Type': 'application/msgpack'})
    packer = msgpack.Packer(
        autoreset=True,
        use_bin_type=True,
    )
    batch = 10
    epoch = 3200
    results = []

    # Health check, then one warm-up request (not timed)
    print(msgpack.unpackb(session.get(url).content, raw=False))
    resp = session.post(
        url,
        data=packer.pack({
            'data': np.random.random((batch, 224, 224, 3)).astype(np.float32).tobytes(),
            'batch': batch,
        })
    )
    print(msgpack.unpackb(resp.content, raw=False))

    # Benchmark: requests are sent one at a time
    t0 = time()
    for _ in tqdm(range(epoch)):
        resp = session.post(
            url,
            data=packer.pack({
                'data': np.random.random((batch, 224, 224, 3)).astype(np.float32).tobytes(),
                'batch': batch,
            })
        )
        # raw=False decodes map keys as str, so ['result'] also works on msgpack < 1.0
        results.append(msgpack.unpackb(resp.content, raw=False)['result'])
    print(np.frombuffer(results[-1], dtype=np.float32))
    print(f'{epoch} epochs with batch {batch} ({epoch * batch}) in {time() - t0} seconds')
```
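
The client above only reports total wall-clock time, which hides tail latency. A minimal sketch that records per-request latency and summarizes it with percentiles; `send` is a hypothetical zero-argument callable wrapping one request (for the client above, something like `lambda: session.post(url, data=...)`).

```python
import statistics
import time
from typing import Callable, Dict, List


def measure_latencies(send: Callable[[], object], n: int, warmup: int = 1) -> List[float]:
    """Call `send` n times after `warmup` untimed calls; return per-call latency in ms."""
    for _ in range(warmup):
        send()
    latencies = []
    for _ in range(n):
        start = time.perf_counter()
        send()
        latencies.append((time.perf_counter() - start) * 1000.0)
    return latencies


def summarize(latencies_ms: List[float]) -> Dict[str, float]:
    """p50/p95/p99 via statistics.quantiles (needs at least 2 samples) plus the mean."""
    cuts = statistics.quantiles(latencies_ms, n=100)  # 99 cut points: cuts[k-1] is the k-th percentile
    return {
        "p50": cuts[49],
        "p95": cuts[94],
        "p99": cuts[98],
        "mean": statistics.fmean(latencies_ms),
    }


if __name__ == "__main__":
    # Stand-in for a real request that takes about 2 ms
    fake_send = lambda: time.sleep(0.002)
    print(summarize(measure_latencies(fake_send, n=50)))
```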

Model: ResNet50

  • GPU: V100
  • Test: multiple models on one GPU card, processing 32000 images (batch 10, 3200 epochs); "x 2" means two models running at the same time

| Framework | GPU RAM | GPU utilization | Time (s) |
|---|---|---|---|
| PyTorch | 2G | 100% | 35 |
| PyTorch x 2 | 4G | 100% | 76 |
| TensorFlow | 8G | 52% | 81 |
| TensorFlow x 2 | 16G | 67-97% | 92 |
| ONNX Runtime GPU | 1G | 52% | 72 |
| ONNX Runtime GPU x 2 | 2G | 68-98% | 103 |

Model: SeResNeXt50

  • GPU: V100
  • Python: 3.8

| Framework | Batch size | Epochs | GPU RAM | GPU utilization | Time (s) |
|---|---|---|---|---|---|
| PyTorch (1.6.0) | 1 | 32000 | 1.4G | 61% | 406 |
| TensorFlow (2.3.0) | 1 | 32000 | 8G | 70% | 860 |
| ONNX Runtime GPU (1.4.0) | 1 | 32000 | 0.9G | 77% | 264 |
| PyTorch (1.6.0) | 10 | 3200 | 2.6G | 93% | 60 |
| TensorFlow (2.3.0) | 10 | 3200 | 8G | 70% | 153 |
| ONNX Runtime GPU (1.4.0) | 10 | 3200 | 1.2G | 60% | 97 |
| PyTorch (1.6.0) | 32 | 1000 | 5.6G | 97% | 48 |
| TensorFlow (2.3.0) | 32 | 1000 | 9G | 70% | 114 |
| ONNX Runtime GPU (1.4.0) | 32 | 1000 | 1.9G | 57% | 91 |
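
One thing the SeResNeXt50 numbers show is that batching helps the three runtimes by very different amounts: ONNX Runtime is the fastest at batch 1 but gains the least from larger batches. The ratio below is plain arithmetic on the table (single runs, so small differences should be read with care).

```python
# Seconds to process 32000 images, copied from the SeResNeXt50 table above
SECONDS = {
    "PyTorch 1.6.0": {1: 406, 10: 60, 32: 48},
    "TensorFlow 2.3.0": {1: 860, 10: 153, 32: 114},
    "ONNX Runtime GPU 1.4.0": {1: 264, 10: 97, 32: 91},
}


def batching_speedup(times_by_batch: dict, batch: int) -> float:
    """How many times faster the run at `batch` is than the batch-1 run."""
    return times_by_batch[1] / times_by_batch[batch]


if __name__ == "__main__":
    for framework, times in SECONDS.items():
        cells = ", ".join(f"batch {b}: {batching_speedup(times, b):.2f}x" for b in sorted(times))
        print(f"{framework}: {cells}")
```

At batch 32 this gives about 8.46x for PyTorch, 7.54x for TensorFlow and 2.90x for ONNX Runtime.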

TensorFlow Serving

https://www.tensorflow.org/tfx/serving

  • tightly coupled with the TensorFlow ecosystem (other model formats can be supported, but not out of the box)
  • A/B testing
  • provides both gRPC and HTTP RESTful APIs
  • Prometheus integration
  • batching
  • multiple models
  • preprocessing & postprocessing can be implemented with signatures
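
The gRPC client further down uses port 8500; the `tensorflow/serving` Docker image serves the REST API on port 8501, with `POST /v1/models/<name>:predict`. A minimal standard-library sketch of a REST request in the row ("instances") format; the model name `resnet50_tf` is taken from the gRPC client, and the helper names are hypothetical. Sending 224x224x3 floats as JSON is far bulkier than the binary gRPC tensor, so this is more useful for smoke tests than for benchmarking.

```python
import json
import urllib.request
from typing import Any, List


def build_predict_request(host: str, model: str, instances: List[Any],
                          signature: str = "serving_default",
                          port: int = 8501) -> urllib.request.Request:
    """POST /v1/models/<model>:predict using the row ("instances") JSON format."""
    url = f"http://{host}:{port}/v1/models/{model}:predict"
    body = json.dumps({"signature_name": signature, "instances": instances}).encode("utf-8")
    return urllib.request.Request(url, data=body, method="POST",
                                  headers={"Content-Type": "application/json"})


def predict(request: urllib.request.Request, timeout: float = 10.0) -> Any:
    """Send the request; the row format answers with {"predictions": [...]}."""
    with urllib.request.urlopen(request, timeout=timeout) as resp:
        return json.loads(resp.read())["predictions"]


if __name__ == "__main__":
    # One all-zero 224x224x3 image as nested lists, since JSON has no tensor type
    image = [[[0.0] * 3 for _ in range(224)] for _ in range(224)]
    req = build_predict_request("localhost", "resnet50_tf", [image])
    print(req.full_url)
    # print(predict(req))  # needs a running tensorflow/serving container
```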

Triton Inference Server

https://github.com/NVIDIA/triton-inference-server/

  • supports multiple backends: ONNX, PyTorch, TensorFlow, Caffe2, TensorRT
  • both gRPC and HTTP APIs, with client SDKs
  • built-in health checks and Prometheus metrics
  • batching
  • concurrent model execution
  • preprocessing & postprocessing can be done with ensemble models
  • the shm-size, memlock and stack settings (Docker run options) are not available on Kubernetes
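
The built-in health check is what a Kubernetes readiness probe would hit. A minimal sketch, assuming a current Triton release with the v2 HTTP API on its default port 8000 (`GET /v2/health/ready`); the older `tensorrtserver` releases such as the 19.10 image listed above used the v1 API with different paths (`/api/health/ready`). Prometheus metrics are served separately, by default on port 8002 at `/metrics`. The function name is hypothetical.

```python
import urllib.error
import urllib.request


def is_ready(base_url: str = "http://localhost:8000", timeout: float = 2.0) -> bool:
    """True if GET <base_url>/v2/health/ready returns HTTP 200, False on any error or non-200."""
    try:
        with urllib.request.urlopen(f"{base_url}/v2/health/ready", timeout=timeout) as resp:
            return resp.status == 200
    except OSError:
        # URLError (and its subclass HTTPError for non-2xx answers) and socket timeouts
        # are all OSError subclasses
        return False


if __name__ == "__main__":
    print("Triton ready:", is_ready())
```

In Kubernetes itself an `httpGet` readiness probe on the same path and port does the same job without any client code.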

Multi Model Server

https://github.com/awslabs/multi-model-server

  • requires Java 8
  • provides an HTTP API
  • the Java layer communicates with Python workers through a Unix domain socket or TCP
  • batching
  • multiple models
  • logging via log4j
  • management API
  • you need to write the model loading and inference code yourself (which means you can use any runtime you want)
  • easy to add preprocessing and postprocessing to the service

TorchServe

https://github.com/pytorch/serve

  • forked from AWS Multi Model Server
  • only for PyTorch models
  • has no PyTorch-specific optimizations

GraphPipe

https://oracle.github.io/graphpipe

  • uses FlatBuffers for serialization, which is more efficient
  • no updates for about 2 years (at the time of writing)

```python
# Baseline: run the ResNet50 SavedModel directly in TensorFlow, with no serving layer
from time import time

import numpy as np
import tensorflow as tf
from tqdm import tqdm

# Expose only the second GPU to TensorFlow
gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_visible_devices(gpus[1], 'GPU')
# tf.config.experimental.set_memory_growth(gpus[1], True)

if __name__ == "__main__":
    # Load the model
    model = tf.saved_model.load('/tmp/resnet50_tf/1/')

    # Warm-up call (not timed)
    epoch = 3200
    batch = 10
    data = np.random.random((batch, 224, 224, 3)).astype(np.float32)
    print(model(inputs=data).numpy())

    # Benchmark: .numpy() copies the result to the host, so each call finishes before the next
    t0 = time()
    results = []
    for _ in tqdm(range(epoch)):
        data = np.random.random((batch, 224, 224, 3)).astype(np.float32)
        result = model(inputs=data).numpy()
        results.append(result)
    print(f'{epoch} epochs with batch {batch} ({batch * epoch}) in {time() - t0} seconds')
```

```python
# TensorFlow Serving benchmark client (gRPC on port 8500)
from time import time

import grpc
import numpy as np
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
from tqdm import tqdm

if __name__ == "__main__":
    channel = grpc.insecure_channel('localhost:8500')
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    # Build the request once and only swap the input tensor in the loop.
    # See prediction_service.proto for gRPC request/response details.
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'resnet50_tf'
    request.model_spec.signature_name = 'serving_default'
    batch = 10
    epoch = 3200
    results = []
    t0 = time()
    for _ in tqdm(range(epoch)):
        data = np.random.random((batch, 224, 224, 3)).astype(np.float32)
        request.inputs['input_1'].CopyFrom(
            tf.make_tensor_proto(data, shape=[batch, 224, 224, 3]))
        result = stub.Predict(request, 10.0)  # 10-second timeout
        results.append(result)
    print(f'run {epoch} epochs with batch {batch} ({batch * epoch}) in {time() - t0} s')
```
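
All the client scripts here send one request at a time, so server-side batching and concurrent model execution (listed as features above) have nothing to combine, and the measured time mostly reflects per-request latency. A minimal sketch for keeping several requests in flight with a thread pool; `send` is a hypothetical callable wrapping one request. A gRPC channel can be shared between threads, but each call should build its own `PredictRequest`, since the loop above mutates a single one in place.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List, Tuple


def run_concurrent(send: Callable[[int], Any], total_requests: int,
                   concurrency: int) -> Tuple[List[Any], float]:
    """Call send(i) for i in range(total_requests) on `concurrency` threads.

    Returns the results in request order and the elapsed wall-clock seconds.
    """
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=concurrency) as pool:
        results = list(pool.map(send, range(total_requests)))
    return results, time.perf_counter() - start


if __name__ == "__main__":
    # Stand-in for a network call that takes about 20 ms
    def fake_send(i: int) -> int:
        time.sleep(0.02)
        return i

    for workers in (1, 4, 16):
        _, elapsed = run_concurrent(fake_send, 64, workers)
        print(f"concurrency {workers}: {64 / elapsed:.0f} requests/s")
```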

```python
# Triton (then "TensorRT Inference Server") benchmark client, legacy tensorrtserver gRPC API on port 8001
from time import time

import grpc
import numpy as np
from tensorrtserver.api import api_pb2
from tensorrtserver.api import grpc_service_pb2
from tensorrtserver.api import grpc_service_pb2_grpc
from tqdm import tqdm

if __name__ == '__main__':
    batch = 32
    channel = grpc.insecure_channel('localhost:8001')
    stub = grpc_service_pb2_grpc.GRPCServiceStub(channel)

    # Request header: latest model version, classification output with 1000 classes
    req = grpc_service_pb2.InferRequest()
    req.model_name = 'resnet50'
    req.model_version = -1
    req.meta_data.batch_size = batch
    output = api_pb2.InferRequestHeader.Output()
    output.name = 'predictions'
    output.cls.count = 1000
    req.meta_data.output.append(output)
    req.meta_data.input.add(name='input_1')

    # Warm-up request (not timed): one raw input holding the whole batch as float32 bytes
    del req.raw_input[:]
    data = b''
    for _ in range(batch):
        data += np.random.random((224, 224, 3)).astype(np.float32).tobytes()
    req.raw_input.append(data)
    print(stub.Infer(req))

    # Benchmark
    t0 = time()
    epoch = 1000
    results = []
    for _ in tqdm(range(epoch)):
        del req.raw_input[:]
        data = b''
        for _image in range(batch):
            data += np.random.random((224, 224, 3)).astype(np.float32).tobytes()
        req.raw_input.append(data)
        results.append(stub.Infer(req))
    print(f'{epoch} epochs with batch {batch} ({batch * epoch}) in {time() - t0} seconds')
```

@kemingy

kemingy commented Nov 7, 2020

> I recommend using https://github.com/triton-inference-server/server/blob/master/docs/perf_analyzer.md to continue your study, and put the result csv following the instruction here https://docs.nvidia.com/deeplearning/triton-inference-server/master-user-guide/docs/optimization.html#visualizing-latency-vs-throughput to understand the triton performance.

Thanks for your advice. I'll try this profiling tool.

> Recently I did some perf analysis on Triton with TRT optimization, we are looking at RN50 about 3k images/s at ~50ms latency.

That's impressive. May I know your test environments? [GPU type + num, quantization?, distillation?]

@mengdong

mengdong commented Nov 7, 2020

1 T4, TRT INT8, no distillation
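
Little's law ties the two quoted numbers together: average items in flight = throughput x latency. A quick check, assuming steady state and that ~50 ms is the end-to-end latency per image (the comment does not say exactly how it was measured). The function name is hypothetical.

```python
def items_in_flight(throughput_per_s: float, latency_s: float) -> float:
    """Little's law: average items in the system = throughput x time each item spends there."""
    return throughput_per_s * latency_s


if __name__ == "__main__":
    # Figures quoted above: ~3k images/s at ~50 ms latency on 1x T4 with TensorRT INT8
    needed = items_in_flight(3000, 0.050)
    print(f"~{needed:.0f} images in flight")
    # The sequential clients in this gist keep a single batch (10 or 32 images) outstanding
    for batch in (10, 32):
        print(f"batch {batch}: {needed / batch:.1f} concurrent requests needed")
```

That is roughly 150 images in flight, far more than the one outstanding batch the sequential clients in this gist ever keep on the server.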
