@justheuristic
Last active February 23, 2023 16:47
import time
import torch
import transformers
from tqdm import trange
#############################
### Evaluation parameters ###
#############################
batch_size = 64
input_length = 512
output_length = 32
config = transformers.OPTConfig(
    hidden_size=12288, ffn_dim=12288 * 4, num_hidden_layers=96, num_attention_heads=96, dropout=0,
)
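# Note: these hyperparameters correspond to OPT-175B (96 decoder layers, hidden size 12288,
# 96 attention heads, FFN size 4 * 12288); dropout is disabled since we only run inference.
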
############################
### Initialize the model ###
############################
def init_opt_8bit_normal(config):  # peak RAM usage: 360GB for OPT-175B
    """RAM-inefficient reference initialization: build the full model, then quantize its decoder layers.

    Used as the reference for the assert torch.allclose correctness check.
    """
    model = transformers.OPTForCausalLM(config)
    model.model.decoder.layers = torch.quantization.quantize_dynamic(
        model.model.decoder.layers,
        {torch.nn.Linear: torch.quantization.get_default_qconfig('fbgemm')},
        dtype=torch.qint8,
        inplace=True,
    )
    return model
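
# Note: quantize_dynamic with the spec above replaces each nn.Linear inside the decoder layers with
# its dynamically quantized counterpart (torch.nn.quantized.dynamic.Linear), which stores int8 weights
# and quantizes activations on the fly at inference time.
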
def init_opt_8bit_efficient(config):  # peak RAM usage: 200GB for OPT-175B
    """Memory-efficient initialization: create, quantize, and attach decoder layers one at a time."""
    actual_num_layers = config.num_hidden_layers
    config.num_hidden_layers = 0  # temporarily set to 0 so the model is built without decoder layers
    partial_model = transformers.OPTForCausalLM(config)
    for layer_index in trange(actual_num_layers):
        block = transformers.models.opt.modeling_opt.OPTDecoderLayer(config)
        block = torch.quantization.quantize_dynamic(
            block,
            {torch.nn.Linear: torch.quantization.get_default_qconfig('fbgemm')},
            dtype=torch.qint8,
            inplace=True,
        )
        with torch.no_grad():
            # run one dummy forward pass through the freshly quantized block
            dummy_output = block(torch.randn(16, 1, config.hidden_size))
            del dummy_output
        partial_model.model.decoder.layers.append(block)
    assert len(partial_model.model.decoder.layers) == actual_num_layers
    config.num_hidden_layers = actual_num_layers  # restore the original config
    model = partial_model
    return model
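
# Optional sanity check (a sketch, not executed here; `small_config` stands for an OPTConfig with only
# a few layers so that both init paths fit in RAM). Given the same underlying weights, the two init
# paths should produce models whose outputs agree, e.g.:
#   reference = init_opt_8bit_normal(small_config)
#   efficient = init_opt_8bit_efficient(small_config)   # with the same weights loaded into both
#   dummy_ids = torch.randint(1, 10000, (2, 8))
#   assert torch.allclose(reference(dummy_ids).logits, efficient(dummy_ids).logits)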
print("Initializing the model")
model = init_opt_8bit_efficient(config)
# Removed: loading the OPT weights (requires the OPT checkpoint downloaded in PyTorch format).
# If you are running this without OPT access: I have checked that the inference time does not change
# when using random weights instead of the actual ones; the only thing that changes is model init time.
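# (Rough sketch of the constraint, in case you add weight loading back: each decoder layer's float
#  state dict would have to be loaded *before* quantize_dynamic is applied to it, since the
#  dynamically quantized Linear modules store packed int8 weights under different parameter names.)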
input_ids = torch.randint(1, 10000, (batch_size, input_length))    # random "prompt" tokens
output_ids = torch.randint(1, 10000, (batch_size, output_length))  # tokens fed one by one during decoding

############################
### The actual benchmark ###
############################
with torch.no_grad():
    print("Processing inputs")
    t_start = time.perf_counter()
    past_key_values = model(input_ids, use_cache=True).past_key_values
    t_inputs_done = time.perf_counter()
    print(f"Processing {batch_size}*{input_length} inputs took {t_inputs_done - t_start:.3f} seconds")
    print(f"Input reading throughput: {batch_size * input_length / (t_inputs_done - t_start):.1f} tokens / second")

    print("Generating outputs")
    for i in trange(output_length):
        # attend to every cached token plus the one new token fed at this step
        attention_mask = torch.ones(batch_size, past_key_values[0][0].shape[-2] + 1)
        out = model(input_ids=output_ids[:, i: i + 1], attention_mask=attention_mask,
                    use_cache=True, past_key_values=past_key_values)
        past_key_values = out.past_key_values
        del out
    t_finished = time.perf_counter()
    print(f"Generating {batch_size}*{output_length} tokens took {t_finished - t_inputs_done:.3f} seconds")
    print(f"Generation throughput: {batch_size * output_length / (t_finished - t_inputs_done):.1f} tokens / second")