Skip to content

Instantly share code, notes, and snippets.

@dhruvaray
Last active September 9, 2020 09:41
Show Gist options
  • Save dhruvaray/18c0d1d4f475b2072cc5bba45641e9e7 to your computer and use it in GitHub Desktop.
Save dhruvaray/18c0d1d4f475b2072cc5bba45641e9e7 to your computer and use it in GitHub Desktop.
dummymodelserving
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import time
import threading
import os
import click
def predict_pipeline(compute_unit, model_name, framework):
    """Build a ``fill-mask`` pipeline for ``model_name`` and run one warm-up query.

    Args:
        compute_unit: Zero-based index of the worker calling this function.
            With PyTorch and CUDA available it is mapped round-robin onto the
            visible GPUs; otherwise the pipeline keeps its CPU default.
        model_name: Model identifier handed to ``transformers.pipeline``.
        framework: Framework flag forwarded to ``transformers.pipeline``
            (``"pt"`` selects the PyTorch path handled below).

    Returns:
        The constructed pipeline object.
    """
    # Imported inside the function so the heavy dependency is only loaded by
    # the worker that actually calls this.
    from transformers import pipeline

    # -1 is the pipeline's "run on CPU" device ordinal.
    device = -1
    if framework == "pt":
        import torch

        if torch.cuda.is_available():
            # Same round-robin GPU assignment as predict_raw.
            device = compute_unit % torch.cuda.device_count()
    # NOTE(review): for non-"pt" frameworks the device is left at the CPU
    # default, exactly as before — GPU placement there is still to be done.

    unmasker = pipeline('fill-mask', model=model_name, framework=framework, device=device)

    # Use the tokenizer's own mask token rather than BERT's literal "[MASK]",
    # so checkpoints with a different mask token still get a valid prompt.
    mask_token = unmasker.tokenizer.mask_token
    unmasker(f"Hello I'm a {mask_token} model.")
    return unmasker
def predict_raw(compute_unit, model_name, framework):
    """Load tokenizer and model for ``model_name`` and run one warm-up forward pass.

    Args:
        compute_unit: Zero-based index of the worker calling this function.
            On the PyTorch path with CUDA available, it is mapped round-robin
            onto the visible GPUs. It is not used on the other path.
        model_name: Checkpoint identifier passed to ``from_pretrained``.
        framework: ``"pt"`` takes the PyTorch path; any other value takes the
            TensorFlow path. It is also passed to the tokenizer as
            ``return_tensors``.

    Returns:
        The loaded model object.
    """
    # Imports are local so each worker pulls in only the framework it needs.
    # Auto classes resolve the concrete architecture from the checkpoint, so
    # the benchmark is not restricted to BERT checkpoints.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    text = "Replace me by any text you'd like."
    encoded_input = tokenizer(text, return_tensors=framework)

    if framework == "pt":
        import torch
        from transformers import AutoModel

        model = AutoModel.from_pretrained(model_name)
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{compute_unit % torch.cuda.device_count()}")
            model.to(device)
            # Rebind the result so this does not rely on `.to()` mutating
            # the encoding in place.
            encoded_input = encoded_input.to(device)
        # Warm-up is inference only; skip autograd bookkeeping.
        with torch.no_grad():
            model(**encoded_input)
    else:
        from transformers import TFAutoModel

        model = TFAutoModel.from_pretrained(model_name)
        model(encoded_input)

    return model
def predict(server_mode, compute_unit, model_name, framework):
    """Worker body: load one model instance, report it, then idle forever.

    Args:
        server_mode: Label of the executor style this worker runs under; only
            used in the printed messages.
        compute_unit: Zero-based index of this worker, forwarded to the loader.
        model_name: Checkpoint identifier forwarded to the loader.
        framework: Framework flag forwarded to the loader.

    Returns:
        ``None``, and only when loading raised; on success the function never
        returns because it sleeps in an endless loop (presumably so the loaded
        model stays alive in this worker).
    """
    import traceback

    try:
        model = predict_raw(compute_unit, model_name, framework)
        # Alternative loader, kept as a manual switch:
        # model = predict_pipeline(compute_unit, model_name, framework)

        # flush=True: this worker never exits, so buffered output would never
        # be written when stdout is redirected to a file or pipe.
        print(f"Loaded model {model_name}:{hex(id(model))} in {server_mode}#{compute_unit}"
              f" - {os.getpid()}:{threading.current_thread().ident} using {framework}",
              flush=True)
        while True:
            time.sleep(1)
    except Exception as exc:
        # Nobody inspects the future returned by submit(), so this is the only
        # place a failure becomes visible: say which worker failed, keep the
        # exception type, and include the traceback.
        print(f"{server_mode}#{compute_unit} - {os.getpid()}:{threading.current_thread().ident}"
              f" failed: {exc!r}", flush=True)
        traceback.print_exc()
# Command-line entry point: start `instances` workers, each running `predict`
# with its own compute-unit index, on a thread pool when --server is "thread"
# and on a process pool for any other value.
# Documented with comments rather than a docstring because click would show a
# docstring as the command's --help text.
@click.command()
@click.option(
    '--server', '-s',
    default='process',
    help='server style -- [thread|process] process style maps to TS. thread style maps to TFS/Triton',
)
@click.option(
    '--instances', '-i',
    default='5',
    help='How many compute instances to load?',
    type=click.INT,
)
@click.option(
    '--model', '-m',
    default='bert-base-uncased',
    help='models from https://huggingface.co/models. Example --> bert-base-uncased',
)
@click.option(
    '--framework', '-f',
    default='pt',
    help='Use pt or tf. pt-->pytorch: tf-->tensorflow ',
)
def run_benchmark(server, instances, model, framework):
    # One pool slot per requested instance.
    pool_cls = ThreadPoolExecutor if server == "thread" else ProcessPoolExecutor
    with pool_cls(max_workers=instances) as pool:
        for unit in range(instances):
            pool.submit(predict, server, unit, model, framework)
    # Leaving the `with` block waits for the submitted workers to finish.
# Run the CLI only when this file is executed as a script, so importing the
# module (for example, by a worker process that re-imports it) has no side effects.
if __name__ == "__main__":
    run_benchmark()
@dhruvaray
Copy link
Author

requirements

click
torch
torchvision
tensorflow
transformers

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment