Skip to content

Instantly share code, notes, and snippets.

Created January 24, 2016 05:00
Show Gist options
  • Save mokemokechicken/3caa4ef52f2bb33760a7 to your computer and use it in GitHub Desktop.
Save mokemokechicken/3caa4ef52f2bb33760a7 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# coding: utf-8
"""Sampling Sequence Data from model"""
import numpy as np
import tensorflow as tf
import json
import cPickle as pickle
import itertools as it
from rnnlib import PTBModel
flags = tf.flags
logging = tf.logging

# Command-line options for this script.
flags.DEFINE_string("model", None, "path to model file")
flags.DEFINE_string("sample", "sample.txt", "path to output sampling data")
flags.DEFINE_integer("n", 10000, "number of output data")
flags.DEFINE_boolean("batch", True, "run by batch")

# Module-level handle for reading the parsed options; main() references
# `FLAGS`, which was otherwise never defined.
FLAGS = flags.FLAGS
class GenConfig(object):
    """Small hyper-parameter set for a generation-time model.

    Not referenced by the visible code: s1()/s2() load their config from the
    pickle stored next to the checkpoint and then override batch_size and
    num_steps themselves.
    """

    # Network shape.
    vocab_size = 11
    hidden_size = 20
    num_layers = 2

    # One token of one sequence per step.
    batch_size = 1
    num_steps = 1

    # presumably the dropout keep probability (1.0 == keep everything) — confirm in PTBModel
    keep_prob = 1.0

    # Initialisation / optimisation schedule settings.
    init_scale = 0.1
    learning_rate = 1.0
    lr_decay = 0.5
    max_grad_norm = 5
    max_epoch = 4
    max_max_epoch = 13
def weighted_pick(weights):
    """Draw one index at random with probability proportional to its weight.

    Args:
        weights: 1-D sequence of non-negative numbers; they need not sum to 1.

    Returns:
        int: the sampled index, always in ``range(len(weights))``.

    Raises:
        ValueError: if ``weights`` is empty or its sum is not positive.
    """
    # float64 keeps `threshold < total` strict below, even for float32 input.
    cumulative = np.cumsum(np.asarray(weights, dtype=np.float64))
    if cumulative.size == 0 or cumulative[-1] <= 0:
        raise ValueError("weights must be non-empty with a positive sum")
    # Use the last cumulative value as the total rather than a separate
    # np.sum(): the two can differ in the last bits, and a threshold above
    # cumulative[-1] would produce an out-of-range index.
    total = cumulative[-1]
    threshold = np.random.rand() * total  # scalar in [0, total)
    # side="right" never lands on a zero-weight entry, even for a 0.0 draw.
    return int(np.searchsorted(cumulative, threshold, side="right"))
def sampling(session, model, primes, end_of_sample):
    """Generate one sequence by feeding the model one token at a time.

    The recurrent state is first advanced over ``primes``; then tokens are
    sampled from the model until ``end_of_sample`` is produced.

    Args:
        session: TensorFlow session holding the restored model variables.
        model: model whose ``input_data`` accepts a ``[[token]]`` feed (one
            sequence, one step) and which exposes ``initial_state``,
            ``final_state`` and ``prob``.
        primes: non-empty list of token ids to condition on; its last element
            seeds the first sampling step.
        end_of_sample: token id that terminates generation.

    Returns:
        list: a new list holding ``primes`` followed by every sampled token,
        ending with ``end_of_sample``.

    NOTE(review): there is no length cap; this loops forever if the model never
    emits ``end_of_sample``.
    """
    state = session.run(model.initial_state)
    # Feed every prime except the last only to advance the recurrent state;
    # the last prime is fed by the first iteration of the sampling loop.
    for token in primes[:-1]:
        state, = session.run([model.final_state],
                             {model.input_data: [[token]],
                              model.initial_state: state})
    ret = list(primes)  # copy so the caller's list is never mutated
    cur = primes[-1]
    while cur != end_of_sample:
        prob, state = session.run([model.prob, model.final_state],
                                  {model.input_data: [[cur]],
                                   model.initial_state: state})
        # presumably prob[0] is the next-token distribution for the single row — confirm
        cur = weighted_pick(prob[0])
        ret.append(cur)
    return ret
def weighted_pick2(weights):
    """Row-wise weighted_pick: sample one column index per row.

    Args:
        weights: 2-D array-like; each row holds non-negative weights for one
            batch element (rows need not sum to 1).

    Returns:
        np.ndarray: integer array with one sampled index per row.

    Raises:
        ValueError: if any row has a non-positive sum.
    """
    cumulative = np.cumsum(np.asarray(weights, dtype=np.float64), axis=1)
    # Per-row totals taken from the cumulative sums themselves, so each
    # threshold is strictly below its row's last cumulative value.
    totals = cumulative[:, -1]
    if np.any(totals <= 0):
        raise ValueError("every row of weights must have a positive sum")
    thresholds = np.random.rand(cumulative.shape[0]) * totals
    # Counting entries <= threshold equals searchsorted(side="right") per
    # row, done for all rows at once; zero-weight entries are never chosen.
    return np.sum(cumulative <= thresholds[:, None], axis=1)
def sampling2(session, model, primes, end_of_sample):
    """Batched sampling(): generate ``model.batch_size`` sequences together.

    Args:
        session: TensorFlow session holding the restored model variables.
        model: model with ``batch_size``, ``input_data`` (fed a
            ``(batch_size, 1)`` array), ``initial_state``, ``final_state``
            and ``prob``.
        primes: 2-D array-like of shape ``(n_primes, batch_size)``; row ``i``
            is the ``i``-th prime token for every sequence. The last row seeds
            the first sampling step.
        end_of_sample: token id that terminates a sequence.

    Returns:
        list of lists of int: one list per batch element with its primes and
        sampled tokens, cut after the first ``end_of_sample``.

    NOTE(review): no length cap; loops forever if some row never emits
    ``end_of_sample``.
    """
    batch_size = model.batch_size
    state = session.run(model.initial_state)
    primes = np.asarray(primes)
    # Advance the state over every prime row except the last.
    for row in primes[:-1]:
        state, = session.run([model.final_state],
                             {model.input_data: row.reshape((batch_size, 1)),
                              model.initial_state: state})
    steps = list(primes)  # one entry per time step, each of shape (batch_size,)
    cur = primes[-1]
    finished = np.zeros(batch_size, dtype=bool)
    # Step the whole batch until every row has emitted end_of_sample at least
    # once; tokens a row produces after that are discarded below.
    while not finished.all():
        prob, state = session.run([model.prob, model.final_state],
                                  {model.input_data: cur.reshape((batch_size, 1)),
                                   model.initial_state: state})
        cur = weighted_pick2(prob)
        steps.append(cur)
        finished |= cur == end_of_sample
    ret = []
    # Transpose time-major steps into per-sequence token lists.
    for seq in zip(*steps):
        head = it.takewhile(lambda tok: tok != end_of_sample, seq)
        # int(): numpy integer scalars are not JSON-serialisable everywhere.
        ret.append([int(tok) for tok in head] + [end_of_sample])
    return ret
def main(unused_args):
    """Sample sequences from a trained model and write them to a JSON file.

    Reads the --model, --sample, --n and --batch options. With --batch (the
    default) all sequences come from one batched run via s2(); otherwise they
    are generated one at a time via s1(). Both write to the same output path,
    so exactly one of them runs.

    Raises:
        ValueError: if --model is not set.
    """
    options = tf.flags.FLAGS
    model_file = options.model
    if not model_file:
        raise ValueError("Must set --model to model path")
    gen_data_path = options.sample
    number_of_output = options.n
    if options.batch:
        s2(model_file, gen_data_path, number_of_output)
    else:
        s1(model_file, gen_data_path, number_of_output)
def s1(model_file, gen_data_path, number_of_output):
    """Generate sequences one at a time and dump them to JSON.

    Args:
        model_file: checkpoint path given to ``tf.train.Saver.restore``; the
            pickled model config is read from ``"<model_file>.config"``.
        gen_data_path: path the JSON list of sequences is written to.
        number_of_output: number of sequences to generate.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only point --model at trusted checkpoints.
    # Binary mode: pickle data is bytes and text mode may mangle it.
    with open("%s.config" % model_file, "rb") as f:
        config = pickle.load(f)
    # Override the training shape: one sequence, one token per step.
    config.batch_size = 1
    config.num_steps = 1
    with tf.Graph().as_default(), tf.Session() as session:
        initializer = tf.random_uniform_initializer(-config.init_scale, config.init_scale)
        with tf.variable_scope("model", reuse=False, initializer=initializer):
            m_gen = PTBModel(is_training=False, config=config)
        saver = tf.train.Saver()
        saver.restore(session, model_file)
        # NOTE(review): START_OF_SEQ / END_OF_SEQ are not defined in the
        # visible source — confirm where they are supposed to come from.
        output = [sampling(session, m_gen, [START_OF_SEQ], END_OF_SEQ)
                  for _ in range(number_of_output)]
    with open(gen_data_path, "w") as f:
        json.dump(output, f)
def s2(model_file, gen_data_path, number_of_output):
    """Generate all sequences in one batched run and dump them to JSON.

    Args:
        model_file: checkpoint path given to ``tf.train.Saver.restore``; the
            pickled model config is read from ``"<model_file>.config"``.
        gen_data_path: path the JSON list of sequences is written to.
        number_of_output: number of sequences, used as the model batch size.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only point --model at trusted checkpoints.
    # Binary mode: pickle data is bytes and text mode may mangle it.
    with open("%s.config" % model_file, "rb") as f:
        config = pickle.load(f)
    # One batch row per requested sequence, one token per step.
    config.batch_size = number_of_output
    config.num_steps = 1
    with tf.Graph().as_default(), tf.Session() as session:
        initializer = tf.random_uniform_initializer(-config.init_scale, config.init_scale)
        with tf.variable_scope("model", reuse=False, initializer=initializer):
            m_gen = PTBModel(is_training=False, config=config)
        saver = tf.train.Saver()
        saver.restore(session, model_file)
        # A single prime row: START_OF_SEQ for every sequence in the batch.
        # NOTE(review): START_OF_SEQ / END_OF_SEQ are not defined in the
        # visible source — confirm where they are supposed to come from.
        primes = [[START_OF_SEQ] * m_gen.batch_size]
        output = sampling2(session, m_gen, primes, END_OF_SEQ)
    # Normalise to plain ints: numpy integer scalars are not JSON-serialisable
    # on every platform / Python version.
    output = [[int(token) for token in seq] for seq in output]
    with open(gen_data_path, "w") as f:
        json.dump(output, f)
if __name__ == '__main__':
    # tf.app.run parses the command-line flags, then calls main(argv).
    tf.app.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment