# ezyang/opti.py (gist, created May 12, 2024)
import torch
from transformers import AutoModelForSequenceClassification, AutoConfig, AutoTokenizer
from optimum.bettertransformer import BetterTransformer
torch.set_float32_matmul_precision('high')  # allow TF32 matmuls on supported GPUs
torchscript = False
better_transformer = True
pretrained_model_name = "bert-base-uncased"
num_labels = 2
max_length = 150
do_lower_case = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = AutoConfig.from_pretrained(
    pretrained_model_name, num_labels=num_labels, torchscript=torchscript
)
model = AutoModelForSequenceClassification.from_pretrained(
    pretrained_model_name, config=config
)
if better_transformer:
    # Swap encoder layers for BetterTransformer fastpath kernels.
    model = BetterTransformer.transform(model)
model.to(device)
if not torchscript:
    print("Compiling model with torch.compile")
    model = torch.compile(model)
model.eval()
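# A possible variation (an assumption, not in the original gist): torch.compile
# also accepts mode="reduce-overhead", which uses CUDA graphs to reduce per-call
# overhead for small-batch inference:
#   model = torch.compile(model, mode="reduce-overhead")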
tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name, do_lower_case=do_lower_case
)
dummy_input = "This is a dummy input for torch jit trace"
inputs = tokenizer.encode_plus(
    dummy_input,
    max_length=int(max_length),
    padding="max_length",
    truncation=True,
    add_special_tokens=True,
    return_tensors="pt",
)
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)
if torchscript:
    print("Tracing model")
    # Benchmark the traced module rather than the eager model.
    model = torch.jit.trace(model, (input_ids, attention_mask))
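# The CUDA-event timing below assumes a CUDA device; on CPU, time.perf_counter()
# would be the fallback (this guard is an addition, not in the original gist).
assert device.type == "cuda", "CUDA event timing requires a GPU"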
n_warm_up, n_iter = 100, 1000
# Warm-up runs trigger compilation/autotuning so they are excluded from timing.
with torch.no_grad():
    for i in range(n_warm_up):
        predictions = model(input_ids, attention_mask)
print(f"Warmed up for {n_warm_up} iterations")
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
with torch.no_grad():
    for i in range(n_iter):
        predictions = model(input_ids, attention_mask)
end.record()
torch.cuda.synchronize()  # wait for queued kernels before reading elapsed time
mode = "torch.compile" if not torchscript else "torchscript"
print(f"Avg time taken by {mode} model with BetterTransformer={better_transformer} for 1 inference is {start.elapsed_time(end) / n_iter:.3f} ms")