Skip to content

Instantly share code, notes, and snippets.

@younesbelkada
Created August 2, 2022 09:29
Show Gist options
  • Save younesbelkada/813e2b753d7f56a52ab6e3838959f1a2 to your computer and use it in GitHub Desktop.
Save younesbelkada/813e2b753d7f56a52ab6e3838959f1a2 to your computer and use it in GitHub Desktop.
import time
import torch
import numpy as np
import argparse
from transformers import pipeline
# Command-line configuration for the benchmark; parsed once at import time.
parser = argparse.ArgumentParser(description='Benchmark pipeline runtime for int8 models')
parser.add_argument('--batch_size', default=1, type=int, help='batch_size for experiments')
parser.add_argument('--nb_runs', default=10, type=int, help='number of times for repeating experiments')
parser.add_argument('--nb_gpus', default=7, type=int, help='number of GPUs to use')
parser.add_argument('--seq_length', default=20, type=int, help='maximum number of tokens to generate')
parser.add_argument('--max_memory', default="30GB", type=str, help='Maximum memory to use for each GPU')
# Required: without it args.model is None and pipeline() has no model to load,
# which would only surface after the GPU checks below have already run.
parser.add_argument('--model', type=str, required=True,
                    help='model id or local path passed to transformers.pipeline')
args = parser.parse_args()

# Fail fast on values that make the benchmark meaningless: zero runs leaves
# nothing to average (np.mean([]) is nan) and an empty batch times nothing.
if args.nb_runs < 1:
    parser.error(f"--nb_runs must be a positive integer, got {args.nb_runs}")
if args.batch_size < 1:
    parser.error(f"--batch_size must be a positive integer, got {args.batch_size}")

NB_RUNS = args.nb_runs
BATCH_SIZE = args.batch_size
def get_input(batch_size=None):
    """Build the dummy prompt batch fed to the pipeline.

    Args:
        batch_size: Number of prompts to produce. Defaults to the module-level
            ``BATCH_SIZE`` (the ``--batch_size`` CLI flag) when omitted.

    Returns:
        list[str]: ``batch_size`` copies of the prompt ``"test"``; an empty
        list for a non-positive size.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    # Strings are immutable, so list repetition is safe here.
    return ["test"] * batch_size
def run_pipeline(pipe_fn=None, inputs=None, nb_runs=None):
    """Time repeated calls of the generation pipeline.

    Args:
        pipe_fn: Callable to benchmark. Defaults to the module-level ``pipe``.
        inputs: Argument passed to ``pipe_fn`` on every run. Defaults to the
            module-level ``input_test``.
        nb_runs: Number of timed runs. Defaults to ``NB_RUNS`` (``--nb_runs``).

    Returns:
        list[float]: Wall-clock duration of each run in seconds, measured with
        ``time.perf_counter``.
    """
    if pipe_fn is None:
        pipe_fn = pipe
    if inputs is None:
        inputs = input_test
    if nb_runs is None:
        nb_runs = NB_RUNS

    # torch.cuda.synchronize() raises when no CUDA device is present, so fall
    # back to a no-op in that case.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    total_time = []
    for _ in range(nb_runs):
        # Drain any pending GPU work so it is not billed to this run.
        sync()
        start = time.perf_counter()
        _ = pipe_fn(inputs)
        # CUDA kernel launches are asynchronous: wait for them BEFORE reading
        # the end timestamp, otherwise the sync happens outside the window.
        sync()
        end = time.perf_counter()
        total_time.append(end - start)
    return total_time
def get_gpus_max_memory(max_memory, n_gpus):
    """Map each of the first ``n_gpus`` CUDA device indices to ``max_memory``.

    The resulting dict (device index -> size string such as ``"30GB"``) is the
    shape expected by the ``max_memory`` loading argument used together with
    ``device_map="auto"``.

    Args:
        max_memory: Per-GPU memory cap, e.g. ``"30GB"``; passed through as-is.
        n_gpus: Number of GPUs to use, starting from device 0.

    Returns:
        dict[int, str]: ``{0: max_memory, 1: max_memory, ...}`` with ``n_gpus``
        entries.

    Raises:
        ValueError: If ``n_gpus`` is not positive or exceeds the number of
            visible CUDA devices.
    """
    available = torch.cuda.device_count()
    # Explicit raises rather than assert: asserts are stripped under `python -O`.
    if n_gpus < 1:
        raise ValueError(f"n_gpus must be at least 1, got {n_gpus}")
    if n_gpus > available:
        raise ValueError(
            "You are requesting more GPUs than available GPUs "
            f"({n_gpus} requested, {available} available)"
        )
    return {gpu_index: max_memory for gpu_index in range(n_gpus)}
# Dummy prompt batch plus per-GPU memory caps, both derived from the CLI flags.
input_test = get_input()
mapping_gpu_memory = get_gpus_max_memory(args.max_memory, args.nb_gpus)

# Load the model in fp16; device_map="auto" places it within the per-GPU
# max_memory caps. Generation length and batching come from the CLI flags.
pipe = pipeline(
    model=args.model,
    model_kwargs={
        "device_map": "auto",
        "torch_dtype": torch.float16,
        "max_memory": mapping_gpu_memory,
    },
    max_new_tokens=args.seq_length,
    batch_size=args.batch_size,
    use_fast=False,
)

# Untimed dummy (warm-up) run before the timed runs.
_ = pipe(input_test)

# Report mean and standard deviation of the per-run wall-clock times (seconds).
total_time = run_pipeline()
print(f"Time elapsed: {np.mean(total_time)} +- {np.std(total_time)}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment