@allo-
Created January 16, 2020 23:07
#!/usr/bin/env python3
import argparse
import json
import os
import numpy as np
import tensorflow as tf
import time
import model, sample, encoder_sp as encoder
from accumulate import AccumulatingOptimizer
SAMPLE_DIR = 'samples'
parser = argparse.ArgumentParser(
    description='Strip unneeded data from GPT-2 Checkpoints',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--model_name', metavar='MODEL', type=str, default='117MSP', help='Pretrained model name')
parser.add_argument('--combine', metavar='CHARS', type=int, default=50000, help='Concatenate input files with <|endoftext|> separator into chunks of this minimum size')
parser.add_argument('--batch_size', metavar='SIZE', type=int, default=1, help='Batch size')
parser.add_argument('--only_train_transformer_layers', default=False, action='store_true', help='Restrict training to the transformer blocks.')
parser.add_argument('--restore_from', type=str, default='latest', help='Either "latest", "fresh", or a path to a checkpoint file')
parser.add_argument('--run_name', type=str, default='run1', help='Run id. Name of subdirectory in checkpoint/ and samples/')
def maketree(path):
    # Create the directory if it does not exist yet.
    try:
        os.makedirs(path)
    except OSError:
        pass

def main():
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    CHECKPOINT_DIR = os.path.join('models', args.model_name, 'checkpoint')

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # Build the GPT-2 graph so its variables exist in this session and
        # can be mapped onto the checkpoint being loaded.
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        output = model.model(hparams=hparams, X=context)

        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]

        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        if ckpt is not None:
            print('Loading checkpoint', ckpt)
            saver.restore(sess, ckpt)

        # Saver that stores the complete state
        saver = tf.train.Saver()

        # Count the trainable parameters of the loaded model.
        total_parameters = 0
        for variable in tf.trainable_variables():
            # shape is an array of tf.Dimension
            shape = variable.get_shape()
            variable_parameters = 1
            for dim in shape:
                variable_parameters *= dim.value
            total_parameters += variable_parameters
        print("The model has", total_parameters, "parameters")

        # Saver that stores only the trainable variables
        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars
        saver = tf.train.Saver(
            var_list=all_vars,
            max_to_keep=1)

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-reduced'))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model-reduced'))

        print("Saving without training parameters")
        save()


if __name__ == '__main__':
    main()
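
As a quick check, the variables kept in the reduced checkpoint can be listed with tf.train.list_variables. This is a minimal sketch, assuming the default --model_name 117MSP and --run_name run1 values from the script above; the checkpoint prefix shown is illustrative.

# Sketch: inspect which variables a checkpoint actually contains.
# Assumes TensorFlow 1.x and the default paths used by the script above.
import tensorflow as tf

ckpt_prefix = 'models/117MSP/checkpoint/run1/model-reduced'  # illustrative path
for name, shape in tf.train.list_variables(ckpt_prefix):
    print(name, shape)
# After stripping, only the 'model/...' weights should remain; the optimizer
# slot variables written during fine-tuning should no longer be present.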