Created July 3, 2018 04:16
Task: Fill in the blank in a sentence / training script snippet (in progress)
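The placeholders in the script suggest the following input/target layout. This is only an illustrative sketch; the names `sentence`, `blank`, and the exact ordering/padding conventions are hypothetical, since the real batching lives in data.gen_batch, which is not included in this gist.

# Hypothetical illustration of one fill-in-the-blank training example
sentence = ['the', 'cat', 'sat', 'on', 'the', 'mat']
blank = 2                          # position of the word to predict
forward_ctx = sentence[:blank]     # left context  -> 'forward_X' placeholder (as word indices)
backward_ctx = sentence[blank+1:]  # right context -> 'backward_X' placeholder (reversed in-graph)
target = sentence[blank]           # word to recover -> 'Y' placeholder (as a word index)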
'''
Train
python 3.6.4
tensorflow 1.x (>= 1.5, for softmax_cross_entropy_with_logits_v2)
2018-06-30
'''
import ipdb as pdb
import os
# Set before importing tensorflow (directly or via data/hparams) so the
# C++ logging level takes effect
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import utils
import pickle
from time import strftime, gmtime
from hparams import hp
import data
import tensorflow as tf


def timenow():
    # Get current time
    return strftime("%Y-%m-%d_%H_%M_%S", gmtime())


class Model:
    def __init__(self, hp, logger, wdict, idict, save_dir):
        self.hp = hp
        self.logger = logger
        self.wdict = wdict
        self.idict = idict
        self.save_dir = save_dir
        self.logger.info('<< Model initiated >>')
        with open(os.path.join(self.save_dir, 'params.txt'), 'w') as f:
            f.write(f'--- {timenow()} ---\n')
            for k in self.hp.values().keys():
                f.write(f'{k}: {self.hp.values()[k]}\n')

    def _create_variables(self):
        self.global_step = tf.get_variable(
            'global_step', initializer=tf.constant(0), trainable=False)
        # Index matrices for the forward/backward context and the target word.
        # Axis 0 is consumed as the time axis because dynamic_rnn is called
        # with time_major=True in _RNN().
        self._X_f = tf.placeholder(tf.int32, [None, None], name='forward_X')
        self._X_b = tf.placeholder(tf.int32, [None, None], name='backward_X')
        # Keep the reversed tensor separate from the placeholder; feeding the
        # placeholder in train() would otherwise bypass the reverse op.
        self._X_b_rev = tf.reverse(self._X_b, [0], name='reverse_backward_x')
        self._Y = tf.placeholder(tf.int32, [None], name='Y')
        self.w = tf.get_variable('final_weights',
                                 [self.hp.VOCAB_SIZE, 2*self.hp.HID_SIZE],
                                 tf.float32,
                                 initializer=tf.contrib.layers.xavier_initializer())
        self.b = tf.get_variable('final_biases', [self.hp.VOCAB_SIZE])
        self.Y = tf.one_hot(self._Y, depth=self.hp.VOCAB_SIZE,
                            on_value=1.0, off_value=0.0, axis=-1,
                            name='onehot_y')
        if self.hp.INCLUDE_EMBEDDING:
            self._create_embedding()
        else:
            self.X_f = tf.one_hot(self._X_f, depth=self.hp.VOCAB_SIZE,
                                  on_value=1.0, off_value=0.0, axis=-1,
                                  name='forward_onehot_x')
            self.X_b = tf.one_hot(self._X_b_rev, depth=self.hp.VOCAB_SIZE,
                                  on_value=1.0, off_value=0.0, axis=-1,
                                  name='backward_onehot_x')

    def _create_embedding(self):
        with tf.name_scope('embed'):
            self.embed_matrix = tf.get_variable(  # vocab_size x embed_size
                'embed_matrix', [self.hp.VOCAB_SIZE, self.hp.EMBED_SIZE],
                initializer=tf.random_uniform_initializer())
            # Look up embedding rows by index:
            # (batch_size x context_win) int32 indices
            # => (batch_size x context_win x embed_size)
            self.X_f = tf.nn.embedding_lookup(
                self.embed_matrix, self._X_f, name='forward_embed')
            self.X_b = tf.nn.embedding_lookup(
                self.embed_matrix, self._X_b_rev, name='backward_embed')

    def _create_net_loss(self):
        '''
        Bidirectional fill-in-the-blank network:
        h1 = forward_RNN(X_f)      # final forward hidden state
        h2 = backward_RNN(X_b)     # final backward hidden state
        C = [h1; h2]               # (batch_size x 2*hid_dim)
        yhat = softmax(C w^T + b)  # (batch_size x vocab_size)
        '''
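        # Shape walk-through with hypothetical sizes (e.g. batch_size=32,
        # HID_SIZE=128, VOCAB_SIZE=10000; the real values come from hparams):
        #   h1, h2 : (32, 128) each, final hidden states of the two RNNs
        #   C      : (32, 256) after concatenation along axis 1
        #   logits : (32, 10000) = C (32, 256) x w^T (256, 10000) + b (10000,)
        #   yprob  : (32, 10000) softmax over the vocabulary for each example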
        with tf.name_scope('feedforward_network'):
            h1 = self._RNN(direction='forward')   # (batch_size x hid_dim)
            h2 = self._RNN(direction='backward')  # (batch_size x hid_dim)
            C = tf.concat([h1, h2], axis=1)       # (batch_size x 2*hid_dim)
            logits = tf.matmul(C, tf.transpose(self.w)) + self.b
            yprob = tf.nn.softmax(logits)         # (batch_size x vocab_size)
        with tf.name_scope('Loss'):
            self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=self.Y, logits=logits, name='cross_entropy'))
            self.entropy = -tf.reduce_sum(tf.log(yprob) * self.Y,
                                          axis=1, keepdims=True)
            self.pred = tf.argmax(logits, 1, 'pred')
            self.correct_pred = tf.equal(
                self.pred, tf.argmax(self.Y, 1), name='correct_pred')
            self.accuracy = tf.reduce_mean(
                tf.cast(self.correct_pred, tf.float32), name='accuracy')

    def _RNN(self, direction='forward'):
        '''
        **Notes on tf.nn.dynamic_rnn**
        - 'x' can have shape (batch_size)x(time)x(dimension) if time_major=False, or
          (time)x(batch_size)x(dimension) if time_major=True
        - 'outputs' keeps the same time/batch layout as 'x', with the last
          dimension replaced by the cell's hidden size
        - 'states' is the final state, with shape (batch_size)x(hidden_dim)
        '''
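        # Example with hypothetical sizes: for an input of shape
        # (utt_len=20, batch_size=32, depth) and BasicRNNCell(HID_SIZE=128)
        # under time_major=True, 'outputs' is (20, 32, 128) and 'states' is
        # (32, 128). A vanilla RNN cell's output equals its new state, so
        # outputs[-1] and 'states' hold the same values here.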
        if direction == 'forward':
            with tf.variable_scope('forward_RNN'):
                cell = tf.contrib.rnn.BasicRNNCell(
                    self.hp.HID_SIZE)  # Make RNNCell
                outputs, states = tf.nn.dynamic_rnn(
                    cell, self.X_f, time_major=True, dtype=tf.float32)
        elif direction == 'backward':
            with tf.variable_scope('backward_RNN'):
                cell = tf.contrib.rnn.BasicRNNCell(
                    self.hp.HID_SIZE)  # Make RNNCell
                outputs, states = tf.nn.dynamic_rnn(
                    cell, self.X_b, time_major=True, dtype=tf.float32)
        # with tf.variable_scope('Cell'):
        #     cell = tf.contrib.rnn.MultiRNNCell(
        #         [tf.contrib.rnn.BasicLSTMCell(n_hidden) for _ in range(4)])
        #     outputs, states = tf.nn.dynamic_rnn(
        #         cell, x, time_major=False, dtype=tf.float32)
        last_output = outputs[-1]  # (batch_size x hid_dim), last time step
        last_states = states       # (batch_size x hid_dim), final cell state
        # For BasicRNNCell the output equals the state, so last_output and
        # last_states carry the same values.
        return last_output  # (batch_size x hid_dim)

    def _create_optimizer(self):
        self.optimizer = tf.train.AdamOptimizer(self.hp.LR).minimize(
            self.loss, global_step=self.global_step)

    def _add_summary(self):
        tf.summary.scalar('train_loss', self.loss)
        tf.summary.scalar('train_accuracy', self.accuracy)
        self.merged_summary = tf.summary.merge_all()

    def build_graph(self):
        tf.reset_default_graph()
        self._create_variables()
        self._create_net_loss()
        self._create_optimizer()
        self._add_summary()

    def train(self):
        saver = tf.train.Saver(max_to_keep=10)
        model_dir = os.path.join(self.save_dir, 'checkpoints')
        utils.safe_mkdir(model_dir)
        train_gen = data.gen_batch(
            self.hp.TRAIN_CLEAN_FILE, self.wdict, randomize=True)
        n_all_utts = 0
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            writer = tf.summary.FileWriter(
                os.path.join(self.save_dir, 'summary'), sess.graph)
            for istep in range(self.hp.NUM_TRAIN_STEPS):
                try:
                    batch_f, batch_b, batch_y = next(train_gen)
                except StopIteration:
                    # Restart the generator once all utterances have been consumed
                    train_gen = data.gen_batch(
                        self.hp.TRAIN_CLEAN_FILE, self.wdict, randomize=True)
                    batch_f, batch_b, batch_y = next(train_gen)
                    n_all_utts += 1  # count full passes over all utterances
                # pdb.set_trace()  # debugging breakpoint; enable to inspect a batch
                _, train_loss, train_acc, summary = sess.run(
                    [self.optimizer, self.loss, self.accuracy, self.merged_summary],
                    {self._X_f: batch_f, self._X_b: batch_b, self._Y: batch_y})
                print(istep, self.hp.NUM_TRAIN_STEPS)
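                # The script is marked "in progress": 'saver', 'writer',
                # 'summary', 'train_loss' and 'train_acc' are created but not
                # yet logged or saved anywhere. A minimal sketch of the missing
                # bookkeeping (the interval of 100 steps is hypothetical):
                #   if istep % 100 == 0:
                #       writer.add_summary(summary, global_step=istep)
                #       saver.save(sess, os.path.join(model_dir, 'model'),
                #                  global_step=istep)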


def main():
    # Set logger
    # SAVE_DIR = './result_' + timenow()
    SAVE_DIR = './result'
    utils.safe_mkdir(SAVE_DIR)
    logger = utils.set_logger(SAVE_DIR)
    # Preprocess data
    vocab, wdict, idict, fdict = data.preprocess(hp.TRAIN_FILE, hp.TEST_FILE,
                                                 hp.MIN_SENT_LEN, hp.VOCAB_SIZE,
                                                 hp.DATA_DIR)
    with open(SAVE_DIR + '/dicts.pckl', 'wb') as pckl:
        pickle.dump({'vocab': vocab, 'word2idx': wdict,
                     'idx2word': idict, 'freqdict': fdict}, pckl)
    # Run model
    M = Model(hp, logger, wdict, idict, SAVE_DIR)
    M.build_graph()
    M.train()


if __name__ == '__main__':
    main()