import time
import torch
import transformers
from tqdm import trange
#############################
### Evaluation parameters ###
#############################
batch_size = 64
input_length = 512
output_length = 32
config = transformers.OPTConfig(
    hidden_size=12288, ffn_dim=12288 * 4, num_hidden_layers=96, num_attention_heads=96, dropout=0,
)
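
# rough sanity check that this config matches OPT-175B: a standard transformer
# estimate is 12 * num_layers * hidden_size**2 parameters (4*h^2 for attention,
# 8*h^2 for the FFN since ffn_dim = 4*h), i.e. 12 * 96 * 12288**2 ~= 174e9,
# plus embeddings -- roughly 175B in total
approx_params = 12 * config.num_hidden_layers * config.hidden_size ** 2
print(f"Approximate decoder parameter count: {approx_params / 1e9:.0f}B")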

############################
### Initialize the model ###
############################
def init_opt_8bit_normal(config):  # peak RAM usage: 360 GB for OPT-175B
    """A RAM-inefficient way of initializing the model weights; used to verify outputs via torch.allclose"""
    model = transformers.OPTForCausalLM(config)
    model.model.decoder.layers = torch.quantization.quantize_dynamic(
        model.model.decoder.layers,
        {torch.nn.Linear: torch.quantization.get_default_qconfig('fbgemm')},
        dtype=torch.qint8,
        inplace=True,
    )
    return model

def init_opt_8bit_efficient(config):  # peak RAM usage: 200 GB for OPT-175B
    """Initialize the model with zero decoder layers, then build and quantize one layer at a time"""
    actual_num_layers = config.num_hidden_layers
    config.num_hidden_layers = 0  # temporarily set to 0 for memory-efficient init
    partial_model = transformers.OPTForCausalLM(config)
    for layer_index in trange(actual_num_layers):
        block = transformers.models.opt.modeling_opt.OPTDecoderLayer(config)
        block = torch.quantization.quantize_dynamic(
            block,
            {torch.nn.Linear: torch.quantization.get_default_qconfig('fbgemm')},
            dtype=torch.qint8,
            inplace=True,
        )
        with torch.no_grad():
            # run a dummy forward pass through the freshly quantized block
            dummy_output = block(torch.randn(16, 1, config.hidden_size))
        del dummy_output
        partial_model.model.decoder.layers.append(block)
    assert len(partial_model.model.decoder.layers) == actual_num_layers
    config.num_hidden_layers = actual_num_layers
    model = partial_model
    return model
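
def check_init_paths_match():
    """A sketch of a cheap sanity check (hypothetical helper, not in the benchmark path):
    both init functions should build identical architectures; a full torch.allclose check
    on logits would additionally require loading identical weights into both models."""
    small_config = transformers.OPTConfig(
        hidden_size=64, ffn_dim=256, num_hidden_layers=2, num_attention_heads=4, dropout=0,
    )
    model_a = init_opt_8bit_normal(small_config)
    model_b = init_opt_8bit_efficient(small_config)
    assert str(model_a) == str(model_b)  # same module structure after quantization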
print("Initializing the model") | |
model = init_opt_8bit_efficient(config) | |
# removed: loading opt-weights, requires downloaded opt weights in pytorch format | |
# if you're running this and have no OPT access: i've checked that the inference time does not change | |
# even if you use random instead of actual weights; the only thing that changes is the model init time | |
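# note: to restore the removed step, weights would have to be loaded into each fp32
# OPTDecoderLayer *before* quantize_dynamic inside init_opt_8bit_efficient's loop
# (e.g. via block.load_state_dict on the corresponding slice of the checkpoint);
# once a layer is quantized, its Linear weights live in packed params and no
# longer match the fp32 checkpoint keys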
input_ids = torch.randint(1, 10000, (batch_size, input_length))
output_ids = torch.randint(1, 10000, (batch_size, output_length))

############################
### The actual benchmark ###
############################
with torch.no_grad():
    print("Processing inputs")
    t_start = time.perf_counter()
    past_key_values = model(input_ids, use_cache=True).past_key_values
    t_inputs_done = time.perf_counter()
    print(f"Processing {batch_size}*{input_length} inputs took {t_inputs_done - t_start:.3f} seconds")
    print(f"Input reading throughput: {batch_size * input_length / (t_inputs_done - t_start):.1f} tokens / second")

    print("Generating outputs")
    for i in trange(output_length):
        # attend to all cached past tokens plus the current one
        attention_mask = torch.ones(batch_size, past_key_values[0][0].shape[-2] + 1)
        out = model(input_ids=output_ids[:, i: i + 1], attention_mask=attention_mask,
                    use_cache=True, past_key_values=past_key_values)
        past_key_values = out.past_key_values
        del out

    t_finished = time.perf_counter()
    print(f"Generating {batch_size}*{output_length} tokens took {t_finished - t_inputs_done:.3f} seconds")
    print(f"Generation throughput: {batch_size * output_length / (t_finished - t_inputs_done):.1f} tokens / second")