Use the torch.multiprocessing and share_memory APIs to serve one copy of a BERT model across worker processes
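The script below loads a BERT model once in the parent process, moves its parameters into shared memory with model.share_memory(), and then spawns a configurable number of worker processes that each run a single inference against the shared weights. Because torch.multiprocessing hands shared tensors to children by handle rather than by copy, every worker sees the same underlying weight storage, which makes it easy to compare per-process memory against loading the model independently in each worker.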
import time

import click
import torch
import torch.multiprocessing as mp
from transformers import BertModel


def load_model(model_name):
    model = BertModel.from_pretrained(model_name)
    # Uncomment to benchmark on GPU instead of CPU:
    # device = torch.device("cuda:0")
    # model.to(device)
    # Move the model's parameters into shared memory so that worker
    # processes reuse a single copy of the weights instead of duplicating them.
    model.share_memory()
    return model


def predict(compute_unit, shared_model_proxy, model_name):
    try:
        with torch.no_grad():
            # Import inside the worker so each process constructs its own tokenizer.
            from transformers import BertTokenizer

            tokenizer = BertTokenizer.from_pretrained(model_name)
            text = f"Replace me by any text you'd like.{compute_unit}"
            encoded_input = tokenizer(text, return_tensors="pt")
            # Uncomment to spread workers across the available GPUs:
            # device = torch.device(f"cuda:{compute_unit % torch.cuda.device_count()}")
            # encoded_input.to(device)
            shared_model_proxy(**encoded_input)
            print(f"Loaded model & inferred {model_name} for #{compute_unit}")
            # Keep the worker alive so its memory usage can be inspected.
            while True:
                time.sleep(1)
    except Exception as e:
        print(e)


@click.command()
@click.option('--instances', '-i', default=2, type=click.INT,
              help='How many compute instances to load?')
@click.option('--model', '-m', default='bert-base-uncased',
              help='Models from https://huggingface.co/models. Example --> bert-base-uncased')
def run_benchmark(instances, model):
    # Load the model once in the parent; every worker receives the same shared tensors.
    shared_model_proxy = load_model(model)
    processes = []
    for compute_unit in range(instances):
        p = mp.Process(target=predict, args=(compute_unit, shared_model_proxy, model))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()


if __name__ == "__main__":
    # mp.set_start_method('forkserver', force=True)
    run_benchmark()
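To run it (a sketch, assuming the gist is saved as share_memory_benchmark.py and torch, transformers, and click are installed; the filename and instance count here are illustrative, not part of the gist):

    python share_memory_benchmark.py --instances 4 --model bert-base-uncased

Each worker prints a confirmation after its first forward pass and then sleeps forever, so per-process resident memory can be compared in a tool such as htop; stop the benchmark with Ctrl-C.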