Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
#!/usr/bin/env python
# coding: utf-8
"""Create Model"""
import os
import numpy as np
import tensorflow as tf
import time
import cPickle as pickle
from rnnlib import ptb_iterator, PTBModel, SmallConfig, load_data
flags = tf.flags
logging = tf.logging

# Required inputs/outputs; main() raises ValueError if either is unset.
flags.DEFINE_string("dataset", None, "path to dataset")
flags.DEFINE_string("model", None, "path to output model file")
# When true, main() re-lays train/valid data with padding_data().
flags.DEFINE_boolean("padding", False, "use training data padding")
# Hyper-parameter overrides; a falsy value (None/0) keeps the SmallConfig
# default.
flags.DEFINE_integer("num_steps", None, "num_steps")
flags.DEFINE_integer("batch_size", None, "batch_size")
flags.DEFINE_integer("num_layers", None, "num_layers")
flags.DEFINE_integer("hidden_size", None, "hidden_size")

# Parsed flag values; populated when tf.app.run() parses argv.
FLAGS = flags.FLAGS
def run_epoch(session, m, data, eval_op, verbose=False):
    """Run model ``m`` over ``data`` for one epoch and return its perplexity.

    Args:
      session: TensorFlow session used to evaluate the graph.
      m: model exposing ``batch_size``, ``num_steps``, ``input_data``,
        ``targets``, ``initial_state``, ``final_state`` and ``cost``.
      data: token sequence consumed by ``ptb_iterator``.
      eval_op: op fetched alongside the cost each step (``m.train_op`` to
        train, ``tf.no_op()`` to only evaluate).
      verbose: if true, print progress lines during the epoch.

    Returns:
      ``exp(sum_of_batch_costs / (num_batches * num_steps))``.

    Raises:
      ValueError: if ``data`` yields no batches at all.
    """
    epoch_size = ((len(data) // m.batch_size) - 1) // m.num_steps
    # About ten progress reports per epoch. Clamp to 1 so epochs with fewer
    # than 10 steps do not trigger a modulo-by-zero.
    report_every = max(epoch_size // 10, 1)
    # Reports start at step 10 (as in the original condition), presumably to
    # skip start-up noise in the throughput figure.
    warmup_steps = 10
    start_time = time.time()
    costs = 0.0
    iters = 0
    # The RNN state is threaded from one batch to the next within the epoch.
    state = session.run(m.initial_state)
    fetches = [m.cost, m.final_state, eval_op]
    for step, (x, y) in enumerate(ptb_iterator(data, m.batch_size, m.num_steps)):
        cost, state, _ = session.run(fetches,
                                     {m.input_data: x,
                                      m.targets: y,
                                      m.initial_state: state})
        costs += cost
        iters += m.num_steps
        if (verbose and step >= warmup_steps
                and (step - warmup_steps) % report_every == 0):
            print("%.3f perplexity: %.3f speed: %.0f wps" %
                  (step * 1.0 / max(epoch_size, 1), np.exp(costs / iters),
                   iters * m.batch_size / (time.time() - start_time)))
    if iters == 0:
        raise ValueError("run_epoch: data produced no batches; "
                         "decrease batch_size or num_steps")
    return np.exp(costs / iters)
def padding_data(seq_data, segment_size, end_of_seq_symbol, padding_symbol):
    """Yield ``seq_data`` re-laid so each sequence starts on a segment boundary.

    Output is produced in segments of ``segment_size`` positions. Input
    symbols are copied through until ``end_of_seq_symbol`` is emitted; the
    remainder of that segment is then filled with ``padding_symbol`` and the
    next input symbol starts a fresh segment. A sequence longer than one
    segment simply continues into the next segment.

    Iteration stops as soon as the input is exhausted, so the final segment
    may be shorter than ``segment_size``.

    Args:
      seq_data: iterable of symbols.
      segment_size: number of positions per output segment (must be >= 1).
      end_of_seq_symbol: symbol marking the end of a sequence.
      padding_symbol: symbol used to fill the rest of a segment.

    Yields:
      The re-laid sequence of symbols.

    Raises:
      ValueError: when iteration begins, if ``segment_size`` < 1 (the loop
        would otherwise never terminate).
    """
    if segment_size < 1:
        raise ValueError("segment_size must be >= 1, got %r" % (segment_size,))
    it = iter(seq_data)
    while True:
        ended = False
        for _ in range(segment_size):
            if ended:
                yield padding_symbol
                continue
            try:
                s = next(it)
            except StopIteration:
                # Return explicitly: under PEP 479 a StopIteration escaping a
                # generator body is turned into RuntimeError.
                return
            yield s
            if s == end_of_seq_symbol:
                ended = True
def main(unused_args):
    """Train a PTBModel and save the session and its config after each epoch.

    Reads --dataset and --model plus optional hyper-parameter overrides from
    the command-line flags. If the --model path already exists the session is
    restored from it; otherwise all variables are freshly initialized. Then
    trains for ``config.max_max_epoch`` epochs. After every epoch it saves the
    session to --model and pickles the training config to ``<model>.config``.

    Args:
      unused_args: positional arguments passed by ``tf.app.run``; ignored.

    Raises:
      ValueError: if --dataset or --model is not set.
    """
    # Resolve parsed flags here so this function does not rely on a
    # module-level FLAGS binding.
    opts = flags.FLAGS
    dataset_path = opts.dataset
    model_data_path = opts.model
    use_padding = opts.padding
    if not dataset_path or not model_data_path:
        raise ValueError("Both --dataset and --model must be set")
    train_data, valid_data, test_data = load_data(dataset_path)

    config = SmallConfig()
    # Evaluation config feeds one token per step.
    eval_config = SmallConfig()
    eval_config.batch_size = 1
    eval_config.num_steps = 1
    # num_steps/batch_size only change training; layer count and hidden size
    # must match across configs because the models share variables.
    if opts.num_steps:
        config.num_steps = opts.num_steps
    if opts.batch_size:
        config.batch_size = opts.batch_size
    if opts.num_layers:
        eval_config.num_layers = config.num_layers = opts.num_layers
    if opts.hidden_size:
        eval_config.hidden_size = config.hidden_size = opts.hidden_size

    if use_padding:
        # NOTE(review): 10 is used both as the end-of-sequence id and as the
        # padding id -- presumably the dataset's EOS token; confirm.
        train_data = list(padding_data(train_data, segment_size=config.num_steps,
                                       end_of_seq_symbol=10, padding_symbol=10))
        valid_data = list(padding_data(valid_data, segment_size=config.num_steps,
                                       end_of_seq_symbol=10, padding_symbol=10))

    # Trim training data to a whole number of (num_steps * batch_size) chunks.
    chunk = config.num_steps * config.batch_size
    tsize = len(train_data) // chunk * chunk
    print("train data %d -> %d" % (len(train_data), tsize))
    train_data = train_data[:tsize]

    with tf.Graph().as_default(), tf.Session() as session:
        initializer = tf.random_uniform_initializer(-config.init_scale,
                                                    config.init_scale)
        with tf.variable_scope("model", reuse=None, initializer=initializer):
            m = PTBModel(is_training=True, config=config)
        with tf.variable_scope("model", reuse=True, initializer=initializer):
            mvalid = PTBModel(is_training=False, config=config)
            mtest = PTBModel(is_training=False, config=eval_config)

        saver = tf.train.Saver()
        # NOTE(review): this existence check matches checkpoint formats that
        # write a file at exactly model_data_path; confirm for the TF version.
        if os.path.exists(model_data_path):
            print("restore session from %s" % model_data_path)
            saver.restore(session, model_data_path)
        else:
            print("initialize new session")
            session.run(tf.initialize_all_variables())

        # Build the evaluation no-op once rather than adding a new op to the
        # graph every epoch.
        no_op = tf.no_op()
        for i in range(config.max_max_epoch):
            lr_decay = config.lr_decay ** max(i - config.max_epoch, 0.0)
            learning_rate = config.learning_rate * lr_decay
            m.assign_lr(session, learning_rate)
            print("Epoch: %d Learning rate: %.3f" % (i + 1, learning_rate))
            train_perplexity = run_epoch(session, m, train_data, m.train_op,
                                         verbose=True)
            print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity))
            valid_perplexity = run_epoch(session, mvalid, valid_data, no_op)
            print("Epoch: %d Valid Perplexity: %.3f" % (i + 1, valid_perplexity))
            # Checkpoint every epoch so an interrupted run can be resumed; the
            # config is written alongside so the checkpoint can be rebuilt.
            saver.save(session, model_data_path)
            with open("%s.config" % model_data_path, "wb") as f:
                pickle.dump(config, f)

        # Test-set evaluation is currently disabled:
        # test_perplexity = run_epoch(session, mtest, test_data, no_op)
        # print("Test Perplexity: %.3f" % test_perplexity)
if __name__ == '__main__':
    # Parses command-line flags, then calls this module's main().
    tf.app.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment