Skip to content

Instantly share code, notes, and snippets.

@csvoss
Created February 16, 2019 07:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save csvoss/45e4e244f10707a733586e3e70ba27b0 to your computer and use it in GitHub Desktop.
Save csvoss/45e4e244f10707a733586e3e70ba27b0 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
"""
Interactive sampling script, modified by Chelsea to read each prompt from a
file (given by filename) instead of from a paragraph typed at the prompt.
"""
import fire
import json
import os
import numpy as np
import tensorflow as tf
import model, sample, encoder
def interact_model(
    model_name='117M',
    seed=None,
    nsamples=1,
    batch_size=None,
    length=None,
    temperature=1,
    top_k=0,
):
    """Interactively generate samples conditioned on the contents of a file.

    Loads the hyperparameters and latest checkpoint found under
    ``models/<model_name>``, then repeatedly asks for a filename, prints the
    file's text, and prints ``nsamples`` generated samples for it.  The loop
    ends when standard input reaches end-of-file.

    Args:
        model_name: Sub-directory of ``models/`` holding ``hparams.json`` and
            the checkpoint; also passed to ``encoder.get_encoder``.
        seed: Value given to ``np.random.seed`` and ``tf.set_random_seed``.
        nsamples: Number of samples printed per file; must be a multiple of
            ``batch_size``.
        batch_size: Number of samples produced per ``sess.run`` call.
            ``None`` means 1.
        length: Forwarded to ``sample.sample_sequence`` as ``length``.
            ``None`` means ``hparams.n_ctx // 2``.
        temperature: Forwarded unchanged to ``sample.sample_sequence``.
        top_k: Forwarded unchanged to ``sample.sample_sequence``.

    Raises:
        AssertionError: If ``nsamples`` is not a multiple of ``batch_size``.
        ValueError: If ``length`` exceeds ``hparams.n_ctx``.
    """
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0, \
        "nsamples must be a multiple of batch_size"

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        # Seed *inside* the session block: tf.set_random_seed applies to the
        # current default graph, and the fresh tf.Graph() only becomes the
        # default once the session context is entered.  Seeding earlier would
        # seed a different graph and leave sampling unseeded.
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k
        )

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
        saver.restore(sess, ckpt)

        while True:
            try:
                filename = input("Filename >>> ")
                while not filename:
                    filename = input("Filename >>> ")
            except EOFError:
                # stdin closed (e.g. Ctrl-D): finish cleanly, no traceback.
                print()
                break

            # A mistyped path should not throw away the loaded model; report
            # the problem and ask again.
            try:
                with open(filename, 'r') as fi:
                    raw_text = fi.read()
            except (OSError, UnicodeDecodeError) as err:
                print("Could not read %r: %s" % (filename, err))
                continue

            print(raw_text)
            context_tokens = enc.encode(raw_text)
            # len() rather than truthiness: safe whatever sequence type the
            # encoder returns.  An empty context gives nothing to condition on.
            if len(context_tokens) == 0:
                print("Prompt should not be empty!")
                continue

            generated = 0
            for _ in range(nsamples // batch_size):
                # Drop the first len(context_tokens) positions of every row so
                # only the continuation is decoded (assumes sample_sequence
                # returns the context followed by the new tokens -- confirm).
                out = sess.run(output, feed_dict={
                    context: [context_tokens for _ in range(batch_size)]
                })[:, len(context_tokens):]
                for i in range(batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)
            print("=" * 80)
if __name__ == "__main__":
    # Run only when executed as a script: hand interact_model to python-fire,
    # which builds the command-line interface from the function.
    fire.Fire(interact_model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment