#!/usr/bin/env python
# coding: utf-8
"""Sample sequence data from a trained RNN language model."""
import json
import itertools as it
import cPickle as pickle

import numpy as np
import tensorflow as tf

from rnnlib import PTBModel
flags = tf.flags
logging = tf.logging

flags.DEFINE_string("model", None, "path to the model checkpoint file")
flags.DEFINE_string("sample", "sample.txt", "path for the output sample data")
flags.DEFINE_integer("n", 10000, "number of sequences to generate")
flags.DEFINE_boolean("batch", True, "generate all sequences in one batch")
FLAGS = flags.FLAGS
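
# Example invocation (the script name and paths are illustrative, not part
# of the repo):
#
#   python sample.py --model=./model.ckpt --sample=sample.json --n=1000
#
# The script expects a pickled config next to the checkpoint, at
# "<model>.config" (see s1/s2 below).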
class GenConfig(object):
    """Small config."""
    init_scale = 0.1
    learning_rate = 1.0
    max_grad_norm = 5
    num_layers = 2
    num_steps = 1
    hidden_size = 20
    max_epoch = 4
    max_max_epoch = 13
    keep_prob = 1.0
    lr_decay = 0.5
    batch_size = 1
    vocab_size = 11
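
# Note: GenConfig is kept for reference but is not used below; s1/s2 load
# the training config from the pickled "<model>.config" file instead and
# only override batch_size and num_steps for generation.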
def weighted_pick(weights):
    """Sample one index from an (unnormalized) probability vector.

    Inverts the cumulative sum with a uniform draw, i.e. multinomial
    sampling over the vocabulary.
    """
    t = np.cumsum(weights)
    s = np.sum(weights)
    return int(np.searchsorted(t, np.random.rand(1) * s))
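
# A minimal usage sketch (the distribution here is made up):
#
#   >>> np.random.seed(0)
#   >>> weighted_pick([0.1, 0.2, 0.7])  # returns 0, 1, or 2; 2 most often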
def sampling(session, model, primes, end_of_sample):
    """Generate one sequence, feeding the model a single token at a time.

    `primes` seeds the model; sampling continues until `end_of_sample`
    is drawn. Assumes the model was built with batch_size = num_steps = 1.
    """
    state = model.initial_state.eval()
    # Warm up the recurrent state on all prime tokens except the last.
    for x in primes[:-1]:
        state, = session.run([model.final_state],
                             {model.input_data: [[x]],
                              model.initial_state: state})
    ret = list(primes)  # copy so the caller's list is not mutated
    cur = primes[-1]
    while cur != end_of_sample:
        prob, state = session.run([model.prob, model.final_state],
                                  {model.input_data: [[cur]],
                                   model.initial_state: state})
        sample = weighted_pick(prob[0])
        ret.append(sample)
        cur = sample
    return ret
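
# This is the standard step-by-step RNN sampling loop: with num_steps=1 the
# graph consumes one token per session.run call, and the recurrent state is
# carried across calls through the initial_state/final_state feed pair.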
def weighted_pick2(weights):
    """Row-wise multinomial sampling: draw one index per row of `weights`."""
    ts = np.cumsum(weights, axis=1)
    ss = np.sum(weights, axis=1)
    ret = []
    for t, s in zip(ts, ss):
        ret.append(int(np.searchsorted(t, np.random.rand(1) * s)))
    return np.array(ret)
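
# The Python loop above could be replaced by a fully vectorized sketch
# (equivalent under searchsorted's default side='left' semantics; the
# function name is hypothetical):
#
#   def weighted_pick2_vec(weights):
#       ts = np.cumsum(weights, axis=1)
#       u = np.random.rand(len(ts), 1) * ts[:, -1:]
#       return (ts < u).sum(axis=1)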
def sampling2(session, model, primes, end_of_sample):
    """Generate `model.batch_size` sequences in parallel.

    `primes` is a list of time steps, each holding one token per batch row.
    Rows that have already emitted `end_of_sample` keep sampling, but their
    extra tokens are trimmed off at the end.
    """
    state = model.initial_state.eval()
    primes = np.array(primes)
    # Warm up the recurrent state on all prime steps except the last.
    for x in primes[:-1]:
        state, = session.run([model.final_state],
                             {model.input_data: x.reshape((model.batch_size, 1)),
                              model.initial_state: state})
    batch_samples = list(primes)
    cur = primes[-1]
    finished = np.zeros(model.batch_size, dtype=bool)
    while not all(finished):
        prob, state = session.run([model.prob, model.final_state],
                                  {model.input_data: cur.reshape((model.batch_size, 1)),
                                   model.initial_state: state})
        sample = weighted_pick2(prob)
        batch_samples.append(sample)
        cur = sample
        finished |= cur == end_of_sample
    # Transpose from time-major steps to one list per sequence, truncating
    # each sequence at its first end_of_sample token (inclusive).
    ret = []
    for seq in zip(*batch_samples):
        ret.append(list(it.takewhile(lambda x: x != end_of_sample, seq)) +
                   [end_of_sample])
    return ret
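
# Note: both sampling loops run until every sequence hits end_of_sample, so
# they never terminate if the model fails to emit that token. A max-length
# cutoff (not in the original) would be a natural safeguard.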
START_OF_SEQ = 1  # token id that seeds every sequence
END_OF_SEQ = 10   # token id that terminates a sequence
def main(unused_args):
    model_file = FLAGS.model
    if not model_file:
        raise ValueError("Must set --model to model path")
    gen_data_path = FLAGS.sample
    number_of_output = FLAGS.n
    if FLAGS.batch:
        s2(model_file, gen_data_path, number_of_output)
    else:
        s1(model_file, gen_data_path, number_of_output)
def s1(model_file, gen_data_path, number_of_output):
    """Generate sequences one at a time (batch_size=1)."""
    with open("%s.config" % model_file, "rb") as f:
        config = pickle.load(f)
    config.batch_size = 1
    config.num_steps = 1
    with tf.Graph().as_default(), tf.Session() as session:
        initializer = tf.random_uniform_initializer(-config.init_scale,
                                                    config.init_scale)
        with tf.variable_scope("model", reuse=False, initializer=initializer):
            m_gen = PTBModel(is_training=False, config=config)
        saver = tf.train.Saver()
        saver.restore(session, model_file)
        output = []
        for _ in range(number_of_output):
            output.append(sampling(session, m_gen, [START_OF_SEQ], END_OF_SEQ))
    with open(gen_data_path, "w") as f:
        json.dump(output, f)
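
# s1 builds a batch_size=1 graph and runs one full sampling loop per output
# sequence; s2 below instead generates all sequences as rows of a single
# batch, which amortizes the per-step session.run overhead.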
def s2(model_file, gen_data_path, number_of_output):
    """Generate all sequences at once, one per batch row."""
    with open("%s.config" % model_file, "rb") as f:
        config = pickle.load(f)
    config.batch_size = number_of_output
    config.num_steps = 1
    with tf.Graph().as_default(), tf.Session() as session:
        initializer = tf.random_uniform_initializer(-config.init_scale,
                                                    config.init_scale)
        with tf.variable_scope("model", reuse=False, initializer=initializer):
            m_gen = PTBModel(is_training=False, config=config)
        saver = tf.train.Saver()
        saver.restore(session, model_file)
        output = sampling2(session, m_gen,
                           [[START_OF_SEQ] * m_gen.batch_size], END_OF_SEQ)
    with open(gen_data_path, "w") as f:
        json.dump(output, f)
if __name__ == '__main__':
    tf.app.run()