dummymodelserving
"""Dummy model-serving benchmark: load one model instance per thread or
process and hold it in memory, so resource usage can be compared across
the two serving styles."""

from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import time
import threading
import os

import click


def predict_pipeline(compute_unit, model_name, framework):
    # TODO -- force CUDA usage
    from transformers import pipeline
    unmasker = pipeline('fill-mask', model=model_name, framework=framework)
    unmasker("Hello I'm a [MASK] model.")
    return unmasker


def predict_raw(compute_unit, model_name, framework):
    from transformers import BertTokenizer
    tokenizer = BertTokenizer.from_pretrained(model_name)
    text = "Replace me by any text you'd like."
    encoded_input = tokenizer(text, return_tensors=framework)
    if framework == "pt":
        from transformers import BertModel
        import torch
        model = BertModel.from_pretrained(model_name)
        if torch.cuda.is_available():
            # Spread instances round-robin across the available GPUs.
            device = torch.device(f"cuda:{compute_unit % torch.cuda.device_count()}")
            model.to(device)
            encoded_input = encoded_input.to(device)
        model(**encoded_input)
    else:
        from transformers import TFBertModel
        model = TFBertModel.from_pretrained(model_name)
        model(encoded_input)
    return model


def predict(server_mode, compute_unit, model_name, framework):
    try:
        e = predict_raw(compute_unit, model_name, framework)
        # e = predict_pipeline(compute_unit, model_name, framework)
        print(f"Loaded model {model_name}:{hex(id(e))} in {server_mode}#{compute_unit}"
              f" - {os.getpid()}:{threading.current_thread().ident} using {framework}")
        # Keep the worker alive so memory usage can be observed.
        while True:
            time.sleep(1)
    except Exception as e:
        print(e)


@click.command()
@click.option('--server', '-s', default='process',
              help='Server style -- [thread|process]. Process style maps to TorchServe (TS); '
                   'thread style maps to TF Serving (TFS)/Triton.')
@click.option('--instances', '-i', default=5, help='How many compute instances to load?', type=click.INT)
@click.option('--model', '-m', default='bert-base-uncased',
              help='Models from https://huggingface.co/models. Example --> bert-base-uncased')
@click.option('--framework', '-f', default='pt', help='Use pt or tf. pt --> PyTorch; tf --> TensorFlow.')
def run_benchmark(server, instances, model, framework):
    # Thread style keeps all instances in one process; process style forks one
    # process per instance.
    executor = ThreadPoolExecutor(max_workers=instances) if server == "thread" \
        else ProcessPoolExecutor(max_workers=instances)
    with executor as e:
        for compute_unit in range(instances):
            e.submit(predict, server, compute_unit, model, framework)


if __name__ == "__main__":
    run_benchmark()
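For reference, here is one way to invoke the benchmark (a sketch; it assumes the script above is saved as dummymodelserving.py, which is an assumed filename):

    # Load 5 copies of bert-base-uncased, one per process (TorchServe-style)
    python dummymodelserving.py --server process --instances 5 --model bert-base-uncased --framework pt

    # Load the same model in 5 threads of a single process (TF Serving/Triton-style)
    python dummymodelserving.py --server thread --instances 5 --framework pt

Each worker prints its PID and thread ident and then sleeps forever, so tools such as top or nvidia-smi can be used to compare memory use across the two serving styles.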
requirements
click
torch
torchvision
tensorflow
transformers
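A minimal setup sketch (assuming the list above is saved as requirements.txt, which is an assumed filename):

    pip install -r requirements.txt

Both torch/torchvision and tensorflow are listed because the --framework flag selects between the PyTorch ('pt') and TensorFlow ('tf') code paths at runtime; only the chosen framework is imported inside the worker.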