Created
July 3, 2018 04:14
Star
You must be signed in to star a gist
Task: Fill in the blank in a sentence.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
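'''
Train a bidirectional-RNN model that fills in the missing word of a sentence.
python 3.6.4
tensorflow 1.x -- the code uses tf.nn.softmax_cross_entropy_with_logits_v2 and
the `keepdims` argument, which require TensorFlow >= 1.5 (the originally noted
version, 0.12.1, predates both).
2018-06-30
'''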
import ipdb as pdb | |
import os | |
import utils | |
import pickle | |
from time import strftime, gmtime | |
from hparams import hp | |
import data | |
import tensorflow as tf | |
# Quiet TensorFlow's native (C++) logging; by TF convention level '2' hides
# INFO and WARNING messages and keeps errors.
# NOTE(review): this runs after `import tensorflow`, so anything TF logs during
# import itself may not be filtered -- consider setting it before the import.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='2')
def timenow():
    """Return the current UTC time as a filesystem-safe timestamp.

    The format is ``YYYY-mm-dd_HH_MM_SS``, e.g. ``2018-07-03_04_14_00``.
    """
    utc_now = gmtime()
    return strftime("%Y-%m-%d_%H_%M_%S", utc_now)
class Model:
    """Bidirectional-RNN model that predicts the word missing from a sentence.

    Two single-layer RNNs encode the two sides of the context (presumably the
    words before and after the blank -- confirm against data.gen_batch). The
    backward input is time-reversed in-graph. The last output of each RNN is
    concatenated and projected onto the vocabulary.

    Args:
        hp: hyper-parameter container. ``hp.values()`` must return a dict, and
            it must expose VOCAB_SIZE, HID_SIZE, EMBED_SIZE,
            INCLUDE_EMBEDDING, LR, NUM_TRAIN_STEPS and TRAIN_CLEAN_FILE.
        logger: logger-like object with an ``info()`` method.
        wdict: word -> index mapping, passed to ``data.gen_batch``.
        idict: index -> word mapping (stored; not used inside this class).
        save_dir: existing directory that receives params.txt, checkpoints/
            and summary/.
    """

    def __init__(self, hp, logger, wdict, idict, save_dir):
        self.hp = hp
        self.logger = logger
        self.wdict = wdict
        self.idict = idict
        self.save_dir = save_dir
        self.logger.info('<< Model initiated >>')
        # Record this run's hyper-parameters next to its outputs.
        params = self.hp.values()
        with open(os.path.join(self.save_dir, 'params.txt'), 'w') as f:
            f.write(f'--- {timenow()} ---\n')
            for k, v in params.items():
                f.write(f'{k}: {v}\n')

    def _create_variables(self):
        """Create the placeholders, output-layer parameters and input encodings.

        _RNN runs with ``time_major=True``, so the id placeholders are
        consumed as (utt_len x batch_size) int32 matrices.
        """
        self.global_step = tf.get_variable(
            'global_step', initializer=tf.constant(0), trainable=False)
        self._X_f = tf.placeholder(tf.int32, [None, None], name='forward_X')
        # ``self._X_b`` must stay the placeholder: train() feeds it. Rebinding
        # it to the tf.reverse output (as before) meant the *reversed* tensor
        # was fed directly, which overrides its value and skips the reversal.
        self._X_b = tf.placeholder(tf.int32, [None, None], name='backward_X')
        # Reverse along axis 0, which is time for time-major input.
        self._X_b_rev = tf.reverse(self._X_b, [0], name='reverse_backward_x')
        self._Y = tf.placeholder(tf.int32, [None], name='Y')
        # Output projection: logits = C @ w^T + b, with C of width 2*HID_SIZE.
        self.w = tf.get_variable('final_weights',
                                 [self.hp.VOCAB_SIZE, 2 * self.hp.HID_SIZE],
                                 tf.float32,
                                 initializer=tf.contrib.layers.xavier_initializer())
        self.b = tf.get_variable('final_biases', [self.hp.VOCAB_SIZE])
        self.Y = tf.one_hot(self._Y, depth=self.hp.VOCAB_SIZE,
                            on_value=1.0, off_value=0.0, axis=-1,
                            name='onehot_y')
        if self.hp.INCLUDE_EMBEDDING:
            self._create_embedding()
        else:
            self.X_f = tf.one_hot(self._X_f, depth=self.hp.VOCAB_SIZE,
                                  on_value=1.0, off_value=0.0, axis=-1,
                                  name='foward_onehot_x')
            self.X_b = tf.one_hot(self._X_b_rev, depth=self.hp.VOCAB_SIZE,
                                  on_value=1.0, off_value=0.0, axis=-1,
                                  name='backward_onehot_x')

    def _create_embedding(self):
        """Embed the forward and (time-reversed) backward word ids.

        Ids of shape (utt_len x batch_size) become
        (utt_len x batch_size x EMBED_SIZE).
        """
        with tf.name_scope('embed'):
            # vocab_size x embed_size
            self.embed_matrix = tf.get_variable(
                'embed_matrix', [self.hp.VOCAB_SIZE, self.hp.EMBED_SIZE],
                initializer=tf.random_uniform_initializer())
            self.X_f = tf.nn.embedding_lookup(
                self.embed_matrix, self._X_f, name='foward_embed')
            self.X_b = tf.nn.embedding_lookup(
                self.embed_matrix, self._X_b_rev, name='backward_embed')

    def _create_net_loss(self):
        '''
        Build the network, the loss and the accuracy ops.

        h1 = forward_RNN(X_f)          # batch_size x hid_dim
        h2 = backward_RNN(X_b)         # batch_size x hid_dim
        C = [h1; h2]                   # batch_size x 2*hid_dim
        logits = C @ w^T + b           # batch_size x vocab_size
        yhat = softmax(logits)
        '''
        with tf.name_scope('feedforward_network'):
            h1 = self._RNN(direction='forward')
            h2 = self._RNN(direction='backward')
            C = tf.concat([h1, h2], axis=1)
            logits = tf.matmul(C, tf.transpose(self.w)) + self.b
        with tf.name_scope('Loss'):
            self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=self.Y, logits=logits, name='cross_entropy'))
            # Per-example cross-entropy, shape (batch_size x 1). log_softmax is
            # used instead of tf.log(tf.nn.softmax(...)), which can hit
            # log(0) = -inf and then produce NaN via -inf * 0.
            self.entropy = -tf.reduce_sum(tf.nn.log_softmax(logits) * self.Y,
                                          axis=1, keepdims=True)
        self.pred = tf.argmax(logits, 1, name='pred')
        self.correct_pred = tf.equal(
            self.pred, tf.argmax(self.Y, 1), name='correct_pred')
        self.accuracy = tf.reduce_mean(
            tf.cast(self.correct_pred, tf.float32), name='accuracy')

    def _RNN(self, direction='forward'):
        '''
        Run a single-layer BasicRNNCell over one side of the context.

        **Notes on tf.nn.dynamic_rnn**
        - 'x' is (time)x(batch_size)x(dimension) when time_major=True,
          or (batch_size)x(time)x(dimension) when time_major=False.
        - 'outputs' has the same leading layout as 'x'.
        - 'states' is the final state, (batch_size)x(hid_dim) for this cell.

        Args:
            direction: 'forward' (reads self.X_f) or 'backward' (reads self.X_b).

        Returns:
            The output of the last time step, (batch_size x hid_dim).

        Raises:
            ValueError: if ``direction`` is neither 'forward' nor 'backward'.
        '''
        # Use ==, not `is`: identity comparison of strings only works by
        # accident of interning, and an unknown value used to fall through
        # to an UnboundLocalError.
        if direction == 'forward':
            scope, inputs = 'forward_RNN', self.X_f
        elif direction == 'backward':
            scope, inputs = 'backward_RNN', self.X_b
        else:
            raise ValueError(
                f"direction must be 'forward' or 'backward', got {direction!r}")
        with tf.variable_scope(scope):
            cell = tf.contrib.rnn.BasicRNNCell(self.hp.HID_SIZE)
            outputs, _ = tf.nn.dynamic_rnn(
                cell, inputs, time_major=True, dtype=tf.float32)
        # Time-major, so outputs[-1] is the last time step: (batch_size x hid_dim).
        # NOTE(review): per the TF docs, a BasicRNNCell's output equals its new
        # state, so without sequence_length this should match the final state.
        return outputs[-1]

    def _create_optimizer(self):
        """Create the Adam training op; each run increments global_step."""
        self.optimizer = tf.train.AdamOptimizer(self.hp.LR).minimize(
            self.loss, global_step=self.global_step)

    def _add_summary(self):
        """Register loss/accuracy scalars and merge all summaries into one op."""
        tf.summary.scalar('train_loss', self.loss)
        tf.summary.scalar('train_accuracy', self.accuracy)
        self.merged_summary = tf.summary.merge_all()

    def build_graph(self):
        """Build the whole training graph in a fresh default graph."""
        tf.reset_default_graph()
        self._create_variables()
        self._create_net_loss()
        self._create_optimizer()
        self._add_summary()

    def _new_train_gen(self):
        """Return a fresh randomized batch generator over the training file."""
        return data.gen_batch(
            self.hp.TRAIN_CLEAN_FILE, self.wdict, randomize=True)

    def train(self):
        """Train for hp.NUM_TRAIN_STEPS steps, restarting the data when exhausted.

        Writes a summary every step to <save_dir>/summary and saves a
        checkpoint to <save_dir>/checkpoints when training finishes.
        build_graph() must be called first.
        """
        saver = tf.train.Saver(max_to_keep=10)
        model_dir = os.path.join(self.save_dir, 'checkpoints')
        utils.safe_mkdir(model_dir)
        train_gen = self._new_train_gen()
        n_epochs = 0  # completed passes over the training file
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            writer = tf.summary.FileWriter(
                os.path.join(self.save_dir, 'summary'), sess.graph)
            try:
                for istep in range(self.hp.NUM_TRAIN_STEPS):
                    try:
                        batch_f, batch_b, batch_y = next(train_gen)
                    except StopIteration:
                        # Training data exhausted: start another pass.
                        n_epochs += 1
                        train_gen = self._new_train_gen()
                        batch_f, batch_b, batch_y = next(train_gen)
                    _, train_loss, train_acc, summary = sess.run(
                        [self.optimizer, self.loss, self.accuracy,
                         self.merged_summary],
                        {self._X_f: batch_f, self._X_b: batch_b,
                         self._Y: batch_y})
                    writer.add_summary(summary, global_step=istep)
                    self.logger.info(
                        f'step {istep + 1}/{self.hp.NUM_TRAIN_STEPS} '
                        f'epoch {n_epochs} loss {train_loss:.4f} '
                        f'acc {train_acc:.4f}')
                saver.save(sess, os.path.join(model_dir, 'model'),
                           global_step=self.global_step)
            finally:
                writer.close()
def main():
    """Preprocess the data, save the vocabulary dictionaries and train a model.

    Everything for the run goes under ``./result``. The logger is set up by
    ``utils.set_logger(save_dir)``, the dictionaries are pickled to
    ``dicts.pckl``, and the Model writes its own outputs into the same
    directory.
    """
    # Use './result_' + timenow() instead to keep each run in its own directory.
    save_dir = './result'
    utils.safe_mkdir(save_dir)
    logger = utils.set_logger(save_dir)
    # Preprocess data.
    vocab, wdict, idict, fdict = data.preprocess(
        hp.TRAIN_FILE, hp.TEST_FILE, hp.MIN_SENT_LEN, hp.VOCAB_SIZE,
        hp.DATA_DIR)
    # Persist the vocabulary mappings alongside the run's results.
    with open(os.path.join(save_dir, 'dicts.pckl'), 'wb') as pckl:
        pickle.dump({'vocab': vocab, 'word2idx': wdict,
                     'idx2word': idict, 'freqdict': fdict}, pckl)
    # Build and train the model.
    model = Model(hp, logger, wdict, idict, save_dir)
    model.build_graph()
    model.train()
if __name__ == "__main__":
    # Train only when run as a script, never as a side effect of importing.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment