Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
GPT-2 345M download_model fork
import os
import sys
import requests
from tqdm import tqdm
# Download every file of a released GPT-2 checkpoint into <directory>/<model>.
#
# Usage: download_model.py <model> <directory>
#   <model>      checkpoint name as published in the bucket, e.g. 117M or 345M
#   <directory>  local directory under which a <model> sub-directory is created
#
# Exits with status 1 when the argument count is wrong; raises
# requests.HTTPError when the server answers with an error status.
if len(sys.argv) != 3:
    print('You must enter the model name and a target directory as parameters, '
          'e.g.: download_model.py 117M models')
    sys.exit(1)

model = sys.argv[1]
directory = sys.argv[2]

# Local destination. exist_ok avoids the check-then-create race.
subdir = os.path.join(directory, model)
os.makedirs(subdir, exist_ok=True)

# The remote layout is always models/<model>/<file>; it must not depend on
# where the files are stored locally (and must always use forward slashes).
base_url = "https://storage.googleapis.com/gpt-2/models/" + model + "/"

for filename in ['checkpoint', 'encoder.json', 'hparams.json',
                 'model.ckpt.data-00000-of-00001', 'model.ckpt.index',
                 'model.ckpt.meta', 'vocab.bpe']:
    r = requests.get(base_url + filename, stream=True)
    try:
        # Fail loudly instead of saving an HTTP error page as model data.
        r.raise_for_status()

        # The header may be absent; tqdm accepts total=None for an unknown size.
        file_size = int(r.headers.get("content-length", 0)) or None
        # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
        chunk_size = 1000

        # Open the output only once the response is known to be good, so a
        # failed request does not leave an empty file behind.
        with open(os.path.join(subdir, filename), 'wb') as f:
            with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
                    # The last chunk is usually shorter than chunk_size.
                    pbar.update(len(chunk))
    finally:
        # Release the streamed connection back to the pool.
        r.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment