import tensorflow as tf
from math import log10, ceil


def last_relevant(output, length):
    batch_size = tf.shape(output)[0]
    max_length = int(output.get_shape()[1])
    output_size = int(output.get_shape()[2])
    index = tf.range(0, batch_size) * max_length + (length - 1)
    flat = tf.reshape(output, [-1, output_size])
    relevant = tf.gather(flat, index)
    return relevant
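
# Example (hypothetical shapes, for illustration): given `output` of shape
# [2, 4, 8] and `length` = [3, 1], `last_relevant` flattens `output` to
# [8, 8] and gathers rows 0*4 + 2 and 1*4 + 0, i.e. the hidden state at the
# last valid timestep of each sequence, returning a [2, 8] tensor.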

def build_model(config):
    X = tf.placeholder(tf.int32, [None, config['seq_len'], 12], name='X')
    y = tf.placeholder(tf.int32, [None, 5], name='y')
    use_dropout = tf.placeholder_with_default(tf.constant(1.0), [], name='use_dropout')
    dropout_rate_const = tf.constant(config['dropout_keep_rate'])
    dropout_keep_rate = dropout_rate_const*use_dropout + (1-use_dropout)

    # X_mask = X[:, :, 0]  # [batch_n, seq_len]  # broken
    X_bools = X[:, :, 1:4]  # [batch_n, seq_len, 3]
    X_cmd = X[:, :, 4:8]    # [batch_n, seq_len, 4]
    X_label = X[:, :, 8]    # [batch_n, seq_len]
    X_token = X[:, :, 9]    # [batch_n, seq_len]
    X_word = X[:, :, 11]
    y_bool = y[:, 0:3]      # [batch_n, 3]
    y_label = y[:, 3]       # [batch_n, ]
    y_token = y[:, 4]       # [batch_n, ]

    batch_size = tf.shape(X)[0]
    terminal_mask = tf.cast(y[:, 2], tf.float32)
    data_lengths = tf.reduce_sum(tf.cast(tf.not_equal(X_token, 0), tf.int32), -1)
    word_data_len = tf.reduce_sum(X[:, :, 10], -1)

    # [batch_n, seq_len, hid_dim]
    with tf.variable_scope("embeddings", initializer=tf.contrib.layers.xavier_initializer()):
        E_labl = tf.get_variable('E_labl', (config['label_voc_size'], config['label_hid_dim']))
        X_embedded_labels = tf.nn.embedding_lookup(E_labl, X_label-1)
        E_tokn = tf.get_variable('E_tokn', (config['token_voc_size'], config['token_hid_dim']))
        X_embedded_tokens = tf.nn.embedding_lookup(E_tokn, X_token-1)
        X_embedded_words = tf.nn.embedding_lookup(E_tokn, X_word-1)

    X_path_total = tf.concat(2, [tf.cast(X_bools, tf.float32),
                                 tf.cast(X_cmd, tf.float32),
                                 X_embedded_labels, X_embedded_tokens])
    X_word_total = X_embedded_words

    with tf.variable_scope('path_rnn'):
        output_path, _ = tf.nn.dynamic_rnn(
            tf.nn.rnn_cell.DropoutWrapper(
                tf.nn.rnn_cell.GRUCell(config['hid_dim']),
                input_keep_prob=dropout_keep_rate,
                output_keep_prob=dropout_keep_rate
            ),
            inputs=X_path_total,
            dtype=tf.float32,
            sequence_length=data_lengths,
        )

    with tf.variable_scope('word_rnn'):
        output_word, _ = tf.nn.dynamic_rnn(
            tf.nn.rnn_cell.DropoutWrapper(
                tf.nn.rnn_cell.GRUCell(config['word_hid_dim']),
                input_keep_prob=dropout_keep_rate,
                output_keep_prob=dropout_keep_rate
            ),
            inputs=X_word_total,
            dtype=tf.float32,
            sequence_length=word_data_len,
        )

    path_h = tf.tanh(last_relevant(output_path, data_lengths))   # [batch_size, hid_dim]
    word_h = tf.tanh(last_relevant(output_word, word_data_len))  # [batch_size, word_hid_dim]
    total_h = tf.concat(1, [path_h, word_h])  # [batch_size, hid_dim + word_hid_dim]

    ## rnn output -> logits
    fcc = tf.contrib.layers.fully_connected
    argkw = {'inputs': total_h, 'activation_fn': tf.tanh}
    repr_bool = fcc(num_outputs=3, scope='repr2bool', **argkw)
    repr_labl = fcc(num_outputs=config['label_hid_dim'], scope='repr2labl', **argkw)
    repr_tokn = fcc(num_outputs=config['token_hid_dim'], scope='repr2tokn', **argkw)

    with tf.variable_scope("embeddings", reuse=True):
        logits_labl = tf.matmul(repr_labl, tf.transpose(tf.get_variable('E_labl')))
        logits_tokn = tf.matmul(repr_tokn, tf.transpose(tf.get_variable('E_tokn')))

    ## logits -> losses
    _y = tf.cast(y_bool, tf.float32)
    cr_ent_bool = tf.nn.sigmoid_cross_entropy_with_logits(repr_bool, _y)
    loss_bool = tf.reduce_mean(cr_ent_bool)
    cr_ent_labl = tf.nn.sparse_softmax_cross_entropy_with_logits(logits_labl, y_label-1)
    loss_labl = tf.reduce_mean(cr_ent_labl)
    cr_ent_tokn = tf.nn.sparse_softmax_cross_entropy_with_logits(logits_tokn, y_token-1)
    loss_tokn = tf.reduce_mean(cr_ent_tokn)
    term_mask_normalizer = tf.reduce_sum(terminal_mask)
    loss_tokn_term = tf.reduce_sum(cr_ent_tokn * terminal_mask) / term_mask_normalizer

    ## loss
    loss = loss_bool + loss_labl + loss_tokn_term

    ## errs
    pred_bool = tf.cast(tf.greater(repr_bool, 0.5), tf.int32)
    err_bool_mean = tf.reduce_mean(tf.cast(tf.not_equal(pred_bool, y_bool), tf.float32))
    pred_labl = tf.cast(tf.argmax(logits_labl, 1), tf.int32)
    err_labl_mean = tf.reduce_mean(tf.cast(tf.not_equal(pred_labl, y_label-1), tf.float32))
    pred_tokn = tf.cast(tf.argmax(logits_tokn, 1), tf.int32)
    err_tokn = tf.cast(tf.not_equal(pred_tokn, y_token-1), tf.float32)
    err_tokn_mean = tf.reduce_mean(err_tokn)
    err_tokn_term = tf.reduce_sum(tf.cast(terminal_mask, tf.float32)*err_tokn)
    err_term_mean = err_tokn_term / tf.reduce_sum(tf.cast(terminal_mask, tf.float32))

    prob_bool = tf.sigmoid(repr_bool)
    prob_labl = tf.nn.softmax(logits_labl)
    prob_tokn = tf.nn.softmax(logits_tokn)

    export_map = {
        'inputs': ['X', 'y', 'use_dropout'],
        'outputs': [
            'pred_bool', 'pred_labl', 'pred_tokn',
            'prob_bool', 'prob_labl', 'prob_tokn',
        ],
        'stat': [
            'loss', 'loss_bool', 'loss_tokn_term',
            'err_bool_mean', 'err_labl_mean',
            'err_tokn_mean', 'err_term_mean',
        ],
        'other': [
            'cr_ent_tokn', 'terminal_mask',
            'logits_labl'
        ]
    }
    # `record` is an external helper (not defined in this gist) that packs the
    # listed locals into a namespace-like object, e.g. `model.stat.loss` below.
    return record.from_local('Model', locals(), export_map)


graph = tf.Graph()
with graph.as_default(), tf.device('/cpu:0'):
    gd = tf.train.AdamOptimizer(config['model']['learning_rate'])
    with tf.device(config['exec']['device_id']):
        seed = config['exec']['seed'] if 'seed' in config['exec'] else 1
        tf.set_random_seed(seed)
        model = build_model({**config['model'], **config['data']['specs']})
        train_op = gd.minimize(model.stat.loss)
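

# Minimal usage sketch (assumptions: `batch_X` / `batch_y` and `val_X` / `val_y`
# are numpy arrays shaped like the `X` [batch, seq_len, 12] and `y` [batch, 5]
# placeholders; on TF 0.x before 0.12, use tf.initialize_all_variables()).
with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    # One training step; dropout is active because `use_dropout` defaults to 1.0.
    _, train_loss = sess.run([train_op, model.stat.loss],
                             feed_dict={'X:0': batch_X, 'y:0': batch_y})
    # At evaluation time, feed use_dropout = 0.0 so the keep rate becomes 1.0.
    val_loss = sess.run(model.stat.loss,
                        feed_dict={'X:0': val_X, 'y:0': val_y, 'use_dropout:0': 0.0})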