Skip to content

Instantly share code, notes, and snippets.

@rossgoodwin
Created May 23, 2019 06:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rossgoodwin/d21d97712435c07c78f9063cba5f1cf1 to your computer and use it in GitHub Desktop.
Save rossgoodwin/d21d97712435c07c78f9063cba5f1cf1 to your computer and use it in GitHub Desktop.
Based on Kyle McDonald's Jupyter notebook code for sampling GPT-2 models
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import json
import os
import numpy as np
import tensorflow as tf
import model, sample, encoder
from sys import argv, stdin
# In[2]:
# !ln -s ../models models # hack to make models "appear" in two places
# In[2]:
# Command-line and stdin inputs.
# NOTE(review): sys.argv[0] is the script path, so this two-way unpack only
# succeeds when exactly ONE argument follows the script name; EXP_NAME then
# holds the script path and TEMP_STR that argument. Confirm this is intended.
EXP_NAME, TEMP_STR = argv
SEED = stdin.read()  # all of stdin; passed to generate() as the prompt

# Sampling configuration.
model_name = 'poet3'
seed = None
nsamples = 32
batch_size = 32
length = None
temperature = float(TEMP_STR) # 0 is deterministic
top_k = 0 # 0 means no restrictions

# Explicit check rather than `assert`, which is skipped under `python -O`.
if nsamples % batch_size != 0:
    raise ValueError(
        'nsamples (%d) must be a multiple of batch_size (%d)'
        % (nsamples, batch_size)
    )

# Encoder and hyper-parameters for the chosen model; values found in
# models/<model_name>/hparams.json override the defaults.
enc = encoder.get_encoder(model_name)
hparams = model.default_hparams()
with open(os.path.join('models', model_name, 'hparams.json')) as f:
    hparams.override_from_dict(json.load(f))

# When no length is configured, use half of hparams.n_ctx; a configured
# length larger than hparams.n_ctx is rejected.
if length is None:
    length = hparams.n_ctx // 2
elif length > hparams.n_ctx:
    raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
# In[3]:
# Open the TensorFlow session that the rest of the script (including
# generate()) reaches through the module-level name `sess`.
sess = tf.InteractiveSession()
# replace with this in script:
# with tf.Session(graph=tf.Graph()) as sess:
# In[4]:
# Placeholder for the prompt tokens: int32, first dimension fixed at
# batch_size, second dimension left unspecified.
context = tf.placeholder(tf.int32, [batch_size, None])
# Seed NumPy and TensorFlow with the module-level `seed`
# (None above, so runs are presumably not reproducible — TODO confirm).
np.random.seed(seed)
tf.set_random_seed(seed)
# Sampling op built from the settings above; generate() runs it with
# `context` fed.
output = sample.sample_sequence(
hparams=hparams, length=length,
context=context,
batch_size=batch_size,
temperature=temperature, top_k=top_k
)
# Restore into `sess` the checkpoint path that tf.train.latest_checkpoint
# returns for models/<model_name>.
saver = tf.train.Saver()
ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
saver.restore(sess, ckpt)
# In[9]:
from utils.list_all_files import *
import unicodedata
import os, re, random
# Single-character substitutions applied by remove_special(): a
# non-breaking space, a few accented letters and ligatures, dashes and
# curly quotes, each mapped to a plain ASCII replacement.
mapping = {
    '\xa0': ' ',
    'Æ': 'AE',
    'æ': 'ae',
    'è': 'e',
    'é': 'e',
    'ë': 'e',
    'ö': 'o',
    '–': '-',
    '—': '-',
    '‘': "'",
    '’': "'",
    '“': '"',
    '”': '"',
}


def remove_special(text):
    """Return *text* with every character listed in ``mapping`` replaced.

    Characters that are not keys of ``mapping`` are copied through
    unchanged. A replacement may be longer than one character.
    """
    # dict.get does the membership test and the lookup in one step.
    return ''.join(mapping.get(ch, ch) for ch in text)
def strip_word(word):
    """Return *word* lower-cased, without leading or trailing non-word characters.

    Only the run of regex non-word characters at each end is removed;
    characters inside the word (such as the apostrophe in "carpenter's")
    are kept.
    """
    return re.sub(r'^\W*|\W*$', '', word).lower()
# basenames = []
# all_poems = {}
# total_lines = 0
# words = set()
# for fn in list_all_files('../../scraping/poetry/output'):
# with open(fn) as f:
# original = open(fn).read()
# text = remove_special(original).split('\n')
# poem = text[3:]
# basename = os.path.basename(fn)
# basename = os.path.splitext(basename)[0]
# basenames.append(basename)
# all_poems[basename] = {
# 'url': text[0],
# 'title': text[1],
# 'author': text[2],
# 'poem': poem
# }
# total_lines += len(poem)
# poem = '\n'.join(poem)
# words.update([strip_word(e) for e in poem.split()])
# words.remove('')
# words = list(words)
# print(total_lines)
# In[10]:
def titlecase_word(word):
    """Return *word* with only its first character upper-cased.

    The rest of the word is left untouched, so "carpenter's" becomes
    "Carpenter's" (``str.title()`` would also capitalise the letter after
    the apostrophe). An empty string is returned unchanged.
    """
    # Slicing instead of indexing word[0] avoids IndexError on ''.
    return word[:1].upper() + word[1:]
# titlecase_word("carpenter's"), "carpenter's".title()
# In[11]:
def random_chunk(array, length):
    """Return a random contiguous slice of *array* with up to *length* items.

    The start index is drawn uniformly from every position at which a full
    chunk fits. If *array* holds fewer than *length* items, the slice
    starts at index 0 and so contains the whole of *array*.
    """
    # The last valid start is len(array) - length. The earlier bound
    # subtracted one more, so the final chunk could never be selected.
    last_start = max(0, len(array) - length)
    start = random.randint(0, last_start)
    return array[start:start + length]
def random_item(array):
    """Return one element of *array*, chosen uniformly at random.

    Raises:
        ValueError: if *array* is empty.
    """
    # randrange(n) picks an index in [0, n) directly, with no "- 1" needed.
    return array[random.randrange(len(array))]
# random_chunk(all_poems[basenames[0]]['poem'], 2), titlecase_word(random_item(words))
# In[12]:
# seeds = '''
# blue
# epoch
# ethereal
# ineffable
# iridescent
# nefarious
# oblivion
# quiver
# solitude
# sonorous
# '''.split()
# len(seeds)
# In[13]:
#from utils.progress import progress
# In[26]:
def clean(text):
    """Return the part of *text* that precedes the first ``<|endoftext|>``.

    If the marker does not occur, *text* is returned unchanged.
    """
    # partition stops at the first marker instead of splitting everywhere.
    return text.partition('<|endoftext|>')[0]
def generate(raw_text):
    """Sample continuations of *raw_text* and print each one.

    Encodes *raw_text* with the module-level ``enc``, runs the ``output``
    op in ``sess`` ``nsamples // batch_size`` times with the encoded prompt
    repeated ``batch_size`` times, and decodes every returned row after
    dropping its first ``len(encoded prompt)`` entries.

    Args:
        raw_text: prompt text to continue.

    Returns:
        A list of strings, each being *raw_text* followed by one decoded
        continuation. Each string is also printed after a separator.
    """
    # inspiration = remove_special(inspiration).strip()
    # seed = titlecase_word(seed).strip()
    # raw_text = inspiration + '\n' + seed
    context_tokens = enc.encode(raw_text)
    n_context = len(context_tokens)

    # The fed batch and the separator are the same on every iteration, so
    # build them once.
    batch = [context_tokens] * batch_size
    separator = '\n'*3+'='*64+'\n'*3

    results = []
    for _ in range(nsamples // batch_size):
        out = sess.run(output, feed_dict={context: batch})
        # Named `row` so the imported `sample` module is not shadowed.
        for row in out:
            # NOTE(review): clean() is defined in this file but is not
            # applied here, so the decoded text is used as-is — confirm.
            text = enc.decode(row[n_context:])
            result = raw_text + text
            results.append(result)
            print(separator)
            print(result)
    return results
if __name__ == '__main__':
    # Generate from the prompt read from stdin; generate() prints every
    # sample and the returned list is kept in res_list.
    res_list = generate(SEED)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment