#!/usr/bin/env python | |
# coding: utf-8 | |
"""Create Model""" | |
import os | |
import numpy as np | |
import tensorflow as tf | |
import time | |
import cPickle as pickle | |
from rnnlib import ptb_iterator, PTBModel, SmallConfig, load_data | |
# Command-line flag definitions (tf.flags is a thin wrapper around argparse-style
# flags; parsed values are read through the module-level FLAGS object below).
flags = tf.flags
logging = tf.logging
flags.DEFINE_string("dataset", None, "path to dataset")
flags.DEFINE_string("model", None, "path to output model file")
flags.DEFINE_boolean("padding", False, "use training data padding")
# The integer flags below default to None so main() can distinguish
# "not given" (keep SmallConfig defaults) from an explicit value.
flags.DEFINE_integer("num_steps", None, "num_steps")
flags.DEFINE_integer("batch_size", None, "batch_size")
flags.DEFINE_integer("num_layers", None, "num_layers")
flags.DEFINE_integer("hidden_size", None, "hidden_size")
FLAGS = flags.FLAGS
def run_epoch(session, m, data, eval_op, verbose=False):
    """Run one epoch of the model over `data` and return the perplexity.

    Args:
        session: active TF session.
        m: a PTBModel instance (exposes batch_size, num_steps, cost,
           initial_state, final_state, input_data, targets).
        data: flat sequence of token ids, consumed via ptb_iterator.
        eval_op: op run alongside the cost (e.g. m.train_op, or tf.no_op()
            for evaluation-only passes).
        verbose: if True, print progress/perplexity/speed ~10 times per epoch.

    Returns:
        exp(total cost / total steps), i.e. per-step perplexity.
    """
    epoch_size = ((len(data) // m.batch_size) - 1) // m.num_steps
    start_time = time.time()
    costs = 0.0
    iters = 0
    state = m.initial_state.eval()
    # Report roughly 10 times per epoch. max(1, ...) prevents a
    # ZeroDivisionError when the epoch has fewer than 10 steps
    # (small datasets / large batch_size * num_steps).
    report_interval = max(1, epoch_size // 10)
    for step, (x, y) in enumerate(ptb_iterator(data, m.batch_size, m.num_steps)):
        # Thread the RNN state through the epoch so successive batches
        # continue from where the previous batch left off.
        cost, state, _ = session.run([m.cost, m.final_state, eval_op],
                                     {m.input_data: x,
                                      m.targets: y,
                                      m.initial_state: state})
        costs += cost
        iters += m.num_steps
        if verbose and step % report_interval == 10:
            print("%.3f perplexity: %.3f speed: %.0f wps" %
                  (step * 1.0 / epoch_size, np.exp(costs / iters),
                   iters * m.batch_size / (time.time() - start_time)))
    return np.exp(costs / iters)
def padding_data(seq_data, segment_size, end_of_seq_symbol, padding_symbol):
    """Yield tokens from `seq_data` padded out to segment boundaries.

    Tokens are emitted in fixed-size segments of `segment_size`. Once
    `end_of_seq_symbol` is seen inside a segment, the remainder of that
    segment is filled with `padding_symbol`; the next segment then resumes
    with the following tokens. The generator stops when `seq_data` is
    exhausted (a sequence that runs out mid-segment without an end symbol
    is emitted truncated, matching the segment consumed so far).
    """
    it = iter(seq_data)
    while True:
        ended = False
        for _ in range(segment_size):
            if ended:
                yield padding_symbol
            else:
                # Use the next() builtin (Python 2.6+/3 compatible) instead
                # of it.next(), and return explicitly on exhaustion: under
                # PEP 479 (Python 3.7+) a StopIteration escaping a generator
                # body is converted to RuntimeError.
                try:
                    s = next(it)
                except StopIteration:
                    return
                if s == end_of_seq_symbol:
                    ended = True
                yield s
def main(unused_args):
    """Train the PTB RNN model on FLAGS.dataset and save it to FLAGS.model.

    Builds three weight-sharing model instances (train / valid / test-eval
    configs), optionally restores a previous checkpoint, then runs
    config.max_max_epoch epochs with learning-rate decay, saving the
    checkpoint and a pickled config after every epoch.
    """
    dataset_path = FLAGS.dataset
    model_data_path = FLAGS.model
    use_padding = FLAGS.padding
    if not dataset_path or not model_data_path:
        # Bug fix: the message previously referred to a nonexistent
        # "--train" flag; the actual flag is --dataset.
        raise ValueError("Both --dataset and --model must be set")
    train_data, valid_data, test_data = load_data(dataset_path)
    config = SmallConfig()
    eval_config = SmallConfig()
    # Evaluation runs one token at a time so it works on arbitrary-length input.
    eval_config.batch_size = 1
    eval_config.num_steps = 1
    if FLAGS.num_steps: config.num_steps = FLAGS.num_steps
    if FLAGS.batch_size: config.batch_size = FLAGS.batch_size
    if FLAGS.num_layers:
        eval_config.num_layers = config.num_layers = FLAGS.num_layers
    if FLAGS.hidden_size:
        eval_config.hidden_size = config.hidden_size = FLAGS.hidden_size
    if use_padding:
        # Pad each sequence out to a num_steps boundary so a sequence end
        # never falls mid-segment. NOTE(review): 10 is used as both the
        # end-of-sequence and padding symbol here — presumably the newline
        # token id; confirm against the dataset's vocabulary.
        train_data = list(padding_data(train_data, segment_size=config.num_steps, end_of_seq_symbol=10, padding_symbol=10))
        valid_data = list(padding_data(valid_data, segment_size=config.num_steps, end_of_seq_symbol=10, padding_symbol=10))
        # Truncate to a whole number of (num_steps * batch_size) chunks.
        tsize = len(train_data) // (config.num_steps * config.batch_size) * (config.num_steps * config.batch_size)
        print("train data %d -> %d" % (len(train_data), tsize))
        train_data = train_data[:tsize]
    with tf.Graph().as_default(), tf.Session() as session:
        initializer = tf.random_uniform_initializer(-config.init_scale, config.init_scale)
        # reuse=None creates the variables; reuse=True shares them, so all
        # three models operate on the same weights.
        with tf.variable_scope("model", reuse=None, initializer=initializer):
            m = PTBModel(is_training=True, config=config)
        with tf.variable_scope("model", reuse=True, initializer=initializer):
            mvalid = PTBModel(is_training=False, config=config)
            mtest = PTBModel(is_training=False, config=eval_config)
        saver = tf.train.Saver()
        if os.path.exists(model_data_path):
            print("restore session from %s" % model_data_path)
            saver.restore(session, model_data_path)
        else:
            print("initialize new session")
            tf.initialize_all_variables().run()
        for i in range(config.max_max_epoch):
            # Exponential learning-rate decay after the first max_epoch epochs.
            lr_decay = config.lr_decay ** max(i - config.max_epoch, 0.0)
            m.assign_lr(session, config.learning_rate * lr_decay)
            print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr)))
            train_perplexity = run_epoch(session, m, train_data, m.train_op, verbose=True)
            print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity))
            valid_perplexity = run_epoch(session, mvalid, valid_data, tf.no_op())
            print("Epoch: %d Valid Perplexity: %.3f" % (i + 1, valid_perplexity))
            # Checkpoint after every epoch so interrupted runs can resume.
            saver.save(session, model_data_path)
            with open("%s.config" % model_data_path, "w") as f:
                pickle.dump(config, f)
        #test_perplexity = run_epoch(session, mtest, test_data, tf.no_op())
        #print("Test Perplexity: %.3f" % test_perplexity)
# Script entry point: tf.app.run() parses flags and then calls main().
if __name__ == '__main__':
    tf.app.run()