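"""Generate translations for a pickled list of prompts with a causal LM loaded
in 8-bit across the available GPUs, postprocess each response, and write one
output line per prompt to the output file."""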
import argparse
import os
import pickle
import re
import time

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompts_file", type=str, help="pickle file containing prompts to query")
    parser.add_argument("--output_file", type=str, help="file to save output translations")
    parser.add_argument("--local_rank", required=False, type=int, help="used by dist launchers")
    parser.add_argument("--max-gpu-memory", type=str, default="80GB")
    parser.add_argument("--max-new-tokens", type=int, default=64)
    parser.add_argument("--name", type=str, help="Name path", required=True)
    parser.add_argument("--batch_size", default=1, type=int, help="batch size")
    parser.add_argument("--no-repeat-ngram-size", type=int, default=4)
    parser.add_argument("--num-beams", type=int, default=1)  # generate() requires at least one beam
    parser.add_argument("--early-stopping", action="store_true", help="Early stopping for beam search")
    return parser.parse_args()
t_start = time.time()
args = get_args()
num_tokens = args.max_new_tokens
local_rank = int(os.getenv("LOCAL_RANK", "0"))
rank = local_rank
def print_rank0(*msg):
    if rank != 0:
        return
    print(*msg)
model_name = args.name
print_rank0(f"Loading model {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Decoder-only tokenizers often lack a pad token; reuse EOS so padded batches encode cleanly
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# One max-memory entry per visible GPU, consumed by the "balanced" device map below
world_size = torch.cuda.device_count()
max_memory = {i: args.max_gpu_memory for i in range(world_size)}
print(f"Max memory: {max_memory}")
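# Shard the model across GPUs with an accelerate "balanced" device map and
# bitsandbytes 8-bit weights (load_in_8bit needs the bitsandbytes package installed).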
kwargs = dict(
    device_map="balanced",
    max_memory=max_memory,
    load_in_8bit=True,
)
model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
### Generate
print_rank0(f"*** Starting to generate... ***")
input_sentences = pickle.load(open(args.prompts_file, "rb"))
generate_kwargs = dict(max_new_tokens=num_tokens, num_beams=args.num_beams, early_stopping=args.early_stopping, no_repeat_ngram_size=args.no_repeat_ngram_size)
print_rank0(f"Generate args {generate_kwargs}")
def generate(inputs):
    """Returns a list of zipped inputs, outputs and number of new tokens."""
    input_tokens = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
    for t in input_tokens:
        if torch.is_tensor(input_tokens[t]):
            input_tokens[t] = input_tokens[t].to("cuda:0")
    outputs = model.generate(**input_tokens, **generate_kwargs)
    input_tokens_lengths = [x.shape[0] for x in input_tokens.input_ids]
    output_tokens_lengths = [x.shape[0] for x in outputs]
    total_new_tokens = [o - i for i, o in zip(input_tokens_lengths, output_tokens_lengths)]
    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return zip(inputs, outputs, total_new_tokens)
def postprocess(result):
    query, response, _ = result
    if response:
        # Remove input query
        resp = response.replace(query, "").strip()
        # Remove start and end quotes
        resp = re.sub(r'^"', "", resp)
        resp = re.sub(r'"$', "", resp)
        # Keep only the first line of the output (text after it is hallucination)
        response = resp.strip().split("\n")[0].strip()
    return response
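# Query the model batch by batch, postprocess each generation, and collect the
# cleaned strings that will be written to the output file.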
print_rank0("*** Running generate")
t_generate_start = time.time()
results = []
for idx in tqdm(range(0, len(input_sentences), args.batch_size)):
inputs = input_sentences[idx*args.batch_size: (idx+1)*args.batch_size]
print (inputs)
if not inputs[0]:
results.append((inputs, ""))
else:
generated = generate(inputs)
output = postprocess(generated)
print (output)
results.append(output)
t_generate_span = time.time() - t_generate_start
open(args.output_file,"w+").write("\n".join(results) + "\n")
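# Example invocation (script, model, and file names are illustrative):
#   python generate_8bit.py --name bigscience/bloom --prompts_file prompts.pkl \
#       --output_file translations.txt --max-new-tokens 64 --num-beams 4 --early-stopping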