@gante
Last active December 19, 2022 16:39
Benchmark Hugging Face generation for PT, FLAX, TF with Eager Execution, and TF with XLA.
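# Benchmarks text generation latency for the same checkpoint across four setups: PyTorch (eager),
# Flax with jax.jit, TF eager, and TF compiled with XLA. Configure the run by editing the constants
# below (MODEL_NAME, NUM_RUNS, SAMPLE, NUM_BEAMS, ...). A single visible GPU is assumed.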
import os
import time
from datetime import timedelta
from functools import wraps, partial
from tqdm import tqdm
# JAX imports and settings
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import jax
# PyTorch imports and settings
import torch
from transformers.testing_utils import torch_device
torch.backends.cuda.matmul.allow_tf32 = True  # so that all frameworks use TF32
# TF imports and settings
import tensorflow as tf
physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)
# Transformers imports come last, since they import the frameworks (and the options above must be set first)
from transformers import (
    AutoTokenizer, AutoConfig,
    MODEL_FOR_CAUSAL_LM_MAPPING, AutoModelForCausalLM, TFAutoModelForCausalLM, FlaxAutoModelForCausalLM,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, AutoModelForSeq2SeqLM, TFAutoModelForSeq2SeqLM, FlaxAutoModelForSeq2SeqLM,
)
MODEL_NAME = "distilgpt2"
RUN_FLAX = True
NUM_RUNS = 100
SAMPLE = False
NUM_BEAMS = 1
MAX_NEW_TOKENS = 64
BATCH_SIZE = 1
PAD_MULTIPLE = 64  # XLA retraces for each new input length; pad inputs to a multiple of this to limit retracing
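# e.g. with PAD_MULTIPLE = 64, prompts of 3 and 17 tokens are both padded to 64 tokens, so the
# XLA-compiled generate function is traced once for that shape and then reused across runs.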
TEMPERATURE = 2.0 # used when SAMPLE is True
TOP_K = 50 # used when SAMPLE is True
# Don't check results for generations that are prone to small differences: long generations and many beams.
CHECK_RESULTS = not SAMPLE and (NUM_BEAMS <= 16 and MAX_NEW_TOKENS <= 64)
# Prompts of different lengths, to confirm that XLA handles a varying number of input tokens
INPUT_EXAMPLES = [
    "The dog",
    "Yesterday, there was",
    "Roses are red, Violets are blue",
    "'The quick brown fox jumps over the lazy dog' - he said.",
    "All human beings are born free and equal in dignity and rights. They are endowed with reason and conscience and "
    "should act towards one another in a spirit of brotherhood."
]
def measure_time(function):
    """Decorator that returns the wrapped function's output together with its execution time."""
    @wraps(function)
    def decorated(inputs):
        start = time.time()
        fn_output = function(inputs)
        end = time.time()
        duration = timedelta(seconds=end - start)
        return fn_output, duration
    return decorated

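# Decoder-only (causal LM) checkpoints are padded on the left so that generated tokens follow the
# prompt directly; a pad token is set explicitly because GPT-2-style tokenizers do not define one.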
def get_tokenizer(model_name):
    config = AutoConfig.from_pretrained(model_name)
    if type(config) in MODEL_FOR_CAUSAL_LM_MAPPING:
        tokenizer = AutoTokenizer.from_pretrained(model_name, pad_token="</s>", padding_side="left")
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer

def get_inputs(tokenizer, index, return_tensors, use_xla=False):
    input_sentences = [INPUT_EXAMPLES[i % len(INPUT_EXAMPLES)] for i in range(index, index + BATCH_SIZE)]
    tokenizer_kwargs = {"return_tensors": return_tensors, "padding": True, "truncation": True}
    if use_xla:
        tokenizer_kwargs.update({"pad_to_multiple_of": PAD_MULTIPLE})
    return tokenizer(input_sentences, **tokenizer_kwargs)

def get_model(framework, model_name):
    config = AutoConfig.from_pretrained(model_name)
    model = None
    if type(config) in MODEL_FOR_CAUSAL_LM_MAPPING:
        if framework == "tf":
            model = TFAutoModelForCausalLM.from_pretrained(model_name)
        elif framework == "pt":
            model = AutoModelForCausalLM.from_pretrained(model_name)
        elif framework == "flax":
            model = FlaxAutoModelForCausalLM.from_pretrained(model_name)
    elif type(config) in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING:
        if framework == "tf":
            model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
        elif framework == "pt":
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        elif framework == "flax":
            if model_name == "t5-3b":  # no FLAX weights for this one, but it can be loaded from PT weights
                model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name, from_pt=True)
            else:
                model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name)
    if model is None:
        raise ValueError(f"Invalid framework ({framework}) / model ({model_name}) combination.")
    return model

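# The first call is reported separately: under XLA / jax.jit it includes tracing and compilation,
# so it is excluded from the mean. Durations are collected in microseconds and printed in ms.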
def print_status(all_outputs, all_durations):
    print(f"Execution time -- 1st call: {all_durations[0]/1000:.2f} ms")
    all_durations = all_durations[1:]
    print(f"Execution time -- mean: {(sum(all_durations)/len(all_durations))/1000:.2f} ms")
    all_outputs = all_outputs[1:]
    try:
        mean_length = sum([out.shape[1] for out in all_outputs]) / len(all_outputs)
    except AttributeError:  # Flax generate returns an output object; its sequences live in `.sequences`
        mean_length = sum([out.sequences.shape[1] for out in all_outputs]) / len(all_outputs)
    print(f"Outputs -- mean length: {mean_length:.2f} tokens")

def main_tf_eager():
    model = get_model("tf", MODEL_NAME)
    model.config.pad_token_id = model.config.eos_token_id
    tokenizer = get_tokenizer(MODEL_NAME)

    @measure_time
    def _generate(inputs):
        inputs.update({"do_sample": SAMPLE, "num_beams": NUM_BEAMS, "max_new_tokens": MAX_NEW_TOKENS})
        if SAMPLE:
            inputs.update({"temperature": TEMPERATURE, "top_k": TOP_K})
        return model.generate(**inputs)

    all_durations = []
    all_outputs = []
    for i in tqdm(range(NUM_RUNS // 10)):  # TF Eager is very slow :(
        inputs = get_inputs(tokenizer, index=i, return_tensors="tf")
        gen_out, duration = _generate(inputs)
        all_durations.append(duration.microseconds + duration.seconds * 1e6)
        all_outputs.append(gen_out)
    print_status(all_outputs, all_durations)
    return all_outputs

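# TF + XLA: wrapping `model.generate` in `tf.function(..., jit_compile=True)` compiles the whole
# generation loop with XLA. Compilation happens once per input shape, which is why the inputs are
# padded to a multiple of PAD_MULTIPLE above.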
def main_tf_xla():
    model = get_model("tf", MODEL_NAME)
    model.config.pad_token_id = model.config.eos_token_id
    tokenizer = get_tokenizer(MODEL_NAME)
    xla_generate = tf.function(model.generate, jit_compile=True)

    @measure_time
    def _generate(inputs):
        inputs.update({"do_sample": SAMPLE, "num_beams": NUM_BEAMS, "max_new_tokens": MAX_NEW_TOKENS})
        if SAMPLE:
            inputs.update({"temperature": TEMPERATURE, "top_k": TOP_K})
        return xla_generate(**inputs)

    all_durations = []
    all_outputs = []
    for i in tqdm(range(NUM_RUNS)):
        inputs = get_inputs(tokenizer, index=i, return_tensors="tf", use_xla=True)
        gen_out, duration = _generate(inputs)
        all_durations.append(duration.microseconds + duration.seconds * 1e6)
        all_outputs.append(gen_out)
    print_status(all_outputs, all_durations)
    return all_outputs

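# PyTorch baseline: eager `generate` under `torch.no_grad()`, with the model and inputs on torch_device.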
def main_pt():
    model = get_model("pt", MODEL_NAME)
    model.config.pad_token_id = model.config.eos_token_id
    model.to(torch_device)
    tokenizer = get_tokenizer(MODEL_NAME)

    @measure_time
    def _generate(inputs):
        with torch.no_grad():
            inputs.update({"do_sample": SAMPLE, "num_beams": NUM_BEAMS, "max_new_tokens": MAX_NEW_TOKENS})
            if SAMPLE:
                inputs.update({"temperature": TEMPERATURE, "top_k": TOP_K})
            return model.generate(**inputs)

    all_durations = []
    all_outputs = []
    for i in tqdm(range(NUM_RUNS)):
        inputs = get_inputs(tokenizer, index=i, return_tensors="pt")
        inputs.to(torch_device)
        gen_out, duration = _generate(inputs)
        all_durations.append(duration.microseconds + duration.seconds * 1e6)
        all_outputs.append(gen_out)
    print_status(all_outputs, all_durations)

    del model
    torch.cuda.empty_cache()
    return all_outputs

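# Flax: the generation options are bound with `functools.partial` before `jax.jit`, so they act as
# compile-time constants; like the TF XLA path, jit recompiles for each new (padded) input shape.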
def main_flax():
    model = get_model("flax", MODEL_NAME)
    tokenizer = get_tokenizer(MODEL_NAME)
    model.config.pad_token_id = model.config.eos_token_id

    generate_kwargs = {"do_sample": SAMPLE, "num_beams": NUM_BEAMS, "max_new_tokens": MAX_NEW_TOKENS}
    if SAMPLE:
        generate_kwargs.update({"temperature": TEMPERATURE, "top_k": TOP_K})
    jit_generate = jax.jit(partial(model.generate, **generate_kwargs))

    @measure_time
    def _generate(inputs):
        return jit_generate(**inputs)

    all_durations = []
    all_outputs = []
    for i in tqdm(range(NUM_RUNS)):
        inputs = get_inputs(tokenizer, index=i, return_tensors="np", use_xla=True)
        gen_out, duration = _generate(inputs)
        all_durations.append(duration.microseconds + duration.seconds * 1e6)
        all_outputs.append(gen_out)
    print_status(all_outputs, all_durations)
    return all_outputs

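# Sanity check: decode the first sequence of the first few runs and verify that all frameworks
# produced identical text. Only meaningful for deterministic, short generations (see CHECK_RESULTS).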
def check_outputs(pt_out, eager_out, xla_out, flax_out):
    if CHECK_RESULTS:
        print("\n\nCHECKING OUTPUTS")
        tokenizer = get_tokenizer(MODEL_NAME)
        num_sentences = len(INPUT_EXAMPLES)
        pt_decoded = [
            tokenizer.decode(out[0, :], skip_special_tokens=True) for i, out in enumerate(pt_out) if i <= num_sentences
        ]
        eager_decoded = [
            tokenizer.decode(out[0, :], skip_special_tokens=True) for i, out in enumerate(eager_out) if i <= num_sentences
        ]
        xla_decoded = [
            tokenizer.decode(out[0, :], skip_special_tokens=True) for i, out in enumerate(xla_out) if i <= num_sentences
        ]
        if flax_out:
            flax_decoded = [
                tokenizer.decode(out.sequences[0, :], skip_special_tokens=True)
                for i, out in enumerate(flax_out)
                if i <= num_sentences
            ]

        failed_checks = False
        for i in range(num_sentences):
            failed_this_check = False
            if pt_decoded[i] != eager_decoded[i]:
                print(f"FAILED: pt_decoded[{i}] != eager_decoded[{i}]")
                failed_this_check = True
            if pt_decoded[i] != xla_decoded[i]:
                print(f"FAILED: pt_decoded[{i}] != xla_decoded[{i}]")
                failed_this_check = True
            if flax_out:
                if pt_decoded[i] != flax_decoded[i]:
                    print(f"FAILED: pt_decoded[{i}] != flax_decoded[{i}]")
                    failed_this_check = True
            if failed_this_check:
                print(f"PT    : {pt_decoded[i]}")
                print(f"EAGER : {eager_decoded[i]}")
                print(f"XLA   : {xla_decoded[i]}")
                if flax_out:
                    print(f"FLAX  : {flax_decoded[i]}")
                print("")
            failed_checks |= failed_this_check
        if not failed_checks:
            print("All outputs match")

if __name__ == "__main__":
    if RUN_FLAX:
        print("\n\nFLAX")
        flax_out = main_flax()
    else:
        flax_out = None

    print("\n\nPYTORCH")
    pt_out = main_pt()

    print("\n\nTF (NO XLA)")
    eager_out = main_tf_eager()

    print("\n\nTF (XLA)")
    xla_out = main_tf_xla()

    check_outputs(pt_out, eager_out, xla_out, flax_out)