Skip to content

Instantly share code, notes, and snippets.

@DOUDOU0314
Created June 18, 2021 07:52
Show Gist options
  • Save DOUDOU0314/d7899efe8ff642c2031ae4743e6edcfb to your computer and use it in GitHub Desktop.
Save DOUDOU0314/d7899efe8ff642c2031ae4743e6edcfb to your computer and use it in GitHub Desktop.
Generation example for GPT-J-6B, a 6B-parameter JAX-based Transformer
import collections
import collections.abc
import hashlib
import logging
import os
import time

import requests
import torch
import transformers
from tqdm import tqdm
from transformers import GPTNeoForCausalLM, AutoConfig, GPT2Tokenizer
# Module-level logger used by download() to report progress at INFO level.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Drop any handlers already attached to this logger (e.g. when the module is
# executed again in the same interpreter) so that messages are not emitted
# once per stale handler.
if logger.hasHandlers():
    logger.handlers.clear()

# Send log records to the console (StreamHandler defaults to stderr).
console = logging.StreamHandler()
logger.addHandler(console)
def _check_sha1(filename, sha1_hash):
    """Return True when the SHA-1 digest of `filename` equals `sha1_hash`.

    The file is hashed in 1 MiB blocks so that multi-gigabyte checkpoints
    never have to fit in memory. Hex-digit case is ignored.
    """
    digest = hashlib.sha1()
    with open(filename, 'rb') as f:
        for block in iter(lambda: f.read(1 << 20), b''):
            digest.update(block)
    return digest.hexdigest() == sha1_hash.lower()


def download(url, path=None, overwrite=False, sha1_hash=None):
    """Download a file from a given URL.

    Parameters
    ----------
    url : str
        URL where the file is located.
    path : str, optional
        Destination path for the downloaded file. If it names an existing
        directory, the file is stored inside it under the last URL segment.
        By default the file is stored relative to the current directory as
        ``<second-to-last URL segment>/<last URL segment>``.
    overwrite : bool, optional
        Whether to download again if the destination file already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. An existing file whose hash
        does not match is ignored and downloaded again.

    Returns
    -------
    str
        The file path of the downloaded file.

    Raises
    ------
    RuntimeError
        If the server answers with a status code other than 200.
    UserWarning
        If `sha1_hash` is given and the downloaded content does not match it.
    """
    url_parts = url.split('/')
    if path is None:
        fname = os.path.join(url_parts[-2], url_parts[-1])
    else:
        path = os.path.expanduser(path)
        fname = os.path.join(path, url_parts[-1]) if os.path.isdir(path) else path

    if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        os.makedirs(dirname, exist_ok=True)

        logger.info('Downloading %s from %s...', fname, url)
        # Stream into a side file and rename it only after the transfer has
        # finished: an interrupted download must not leave a truncated file
        # at `fname`, because the existence check above would accept it as
        # complete on the next run.
        part_name = fname + '.part'
        # The context manager returns the connection to the pool even when
        # the transfer fails part-way.
        with requests.get(url, stream=True) as r:
            if r.status_code != 200:
                raise RuntimeError("Failed downloading url %s" % url)
            total_length = r.headers.get('content-length')
            chunks = r.iter_content(chunk_size=1024)
            if total_length is not None:
                # Progress bar counts 1 KiB chunks; it is only possible when
                # the server announces the content length.
                chunks = tqdm(chunks,
                              total=int(int(total_length) / 1024. + 0.5),
                              unit='KB', unit_scale=False, dynamic_ncols=True)
            with open(part_name, 'wb') as f:
                for chunk in chunks:
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
        os.replace(part_name, fname)

        if sha1_hash and not _check_sha1(fname, sha1_hash):
            raise UserWarning('File {} is downloaded but the content hash does not match. '
                              'The repo may be outdated or download may be incomplete. '
                              'If the "repo_url" is overridden, consider switching to '
                              'the default repo.'.format(fname))
    return fname
class Checkpoint(collections.abc.MutableMapping):
    """Read-only mapping wrapper around a checkpoint loaded with ``torch.load``.

    Item assignment and deletion are accepted but ignored, so the wrapped
    checkpoint is never modified through this object.

    Parameters
    ----------
    path : str, optional
        Location of the checkpoint file handed to ``torch.load``. Defaults to
        the file that ``main()`` downloads.
    """

    def __init__(self, path="./gpt-j-hf/pytorch_model.bin"):
        self.checkpoint = torch.load(path)
        print("Loaded")

    def __len__(self):
        return len(self.checkpoint)

    def __getitem__(self, key):
        # NOTE(review): the stored value is passed to torch.load again, so
        # this presumably expects the checkpoint to map each key to a path or
        # file-like object rather than to a tensor - confirm against the
        # actual layout of pytorch_model.bin.
        return torch.load(self.checkpoint[key])

    def __setitem__(self, key, value):
        # Writes are deliberately ignored.
        return

    def __delitem__(self, key, value=None):
        # Deletions are deliberately ignored. `value` is kept (now optional)
        # for backward compatibility; `del obj[key]` only passes `key`.
        return

    def keys(self):
        return self.checkpoint.keys()

    def __iter__(self):
        # NOTE(review): yields (key, value) pairs instead of bare keys, which
        # differs from the usual mapping protocol - presumably what the
        # consumer of this class expects; verify before changing.
        for key in self.checkpoint:
            yield (key, self.__getitem__(key))

    def __copy__(self):
        # Returns the instance attribute dict, not a new Checkpoint.
        return self.__dict__

    def copy(self):
        # Returns the instance attribute dict, not a new Checkpoint.
        return self.__dict__
def main():
    """Download the GPT-J-6B checkpoint, load it and print one sampled text.

    The config and weights are fetched with ``download()`` into
    ``./gpt-j-hf``, the model is loaded from that directory, moved to the GPU
    in half precision, and a continuation of a fixed prompt is sampled and
    printed. Requires network access and a CUDA device.
    """
    urls = ['https://zhisu-nlp.s3.us-west-2.amazonaws.com/gpt-j-hf/config.json',
            'https://zhisu-nlp.s3.us-west-2.amazonaws.com/gpt-j-hf/pytorch_model.bin']
    for url in urls:
        download(url)
    print("download finished", flush=True)

    print("load model", flush=True)
    # Load the weights exactly once. The earlier version first built a model
    # through Checkpoint() and then immediately replaced it with this call,
    # reading the multi-gigabyte checkpoint twice for no effect on `model`.
    model = GPTNeoForCausalLM.from_pretrained("./gpt-j-hf")
    print("ok")
    model.eval()

    tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
    # Per the original author: half precision should take about 12GB of
    # graphics RAM; with a GPU larger than 16GB the half() is not needed.
    model.half().cuda()

    input_text = 'Why AutoGluon is great?'
    input_ids = tokenizer.encode(str(input_text), return_tensors='pt').cuda()
    output = model.generate(
        input_ids,
        do_sample=True,
        max_length=800,
        top_p=0.7,
        top_k=0,
        temperature=1.0,
    )
    print(tokenizer.decode(output[0], skip_special_tokens=True))
# Run the example only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
A generation example for GPT-J-6B, a 6B-parameter JAX-based Transformer.
Command:
python3 generation_example.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment