Skip to content

Instantly share code, notes, and snippets.

@imironhead
Created December 8, 2017 15:54
Show Gist options
  • Save imironhead/f9553e164fc259fa1b1289b13607e43c to your computer and use it in GitHub Desktop.
Save imironhead/f9553e164fc259fa1b1289b13607e43c to your computer and use it in GitHub Desktop.
def _fc_stack(flow, num_layers, num_outputs, scope_prefix, initializer):
    """Apply `num_layers` ReLU fully-connected layers to the last axis of `flow`.

    Layer k lives in variable scope '<scope_prefix>_<k>', the same naming the
    model has always used, so variable names (and saved checkpoints) are
    unaffected by this helper.
    """
    for index in range(num_layers):
        flow = tf.contrib.layers.fully_connected(
            inputs=flow,
            num_outputs=num_outputs,
            activation_fn=tf.nn.relu,
            weights_initializer=initializer,
            scope='{}_{}'.format(scope_prefix, index))
    return flow


def _build_rnn_cell(initializer, dropout):
    """Build the stacked LSTM cell configured by FLAGS.

    Each of the FLAGS.rnn_num layers is a peephole LSTMCell with
    FLAGS.rnn_state_size units projected to FLAGS.rnn_output_num outputs,
    optionally wrapped in a ResidualWrapper (FLAGS.rnn_resnet). When `dropout`
    is not None the whole stack is wrapped in a DropoutWrapper that uses it as
    the input, output and state *keep* probability.
    """
    cells = []
    for _ in range(FLAGS.rnn_num):
        cell = tf.contrib.rnn.LSTMCell(
            FLAGS.rnn_state_size,
            num_proj=FLAGS.rnn_output_num,
            initializer=initializer,
            use_peepholes=True,
            forget_bias=FLAGS.forget_bias,
            state_is_tuple=True)
        if FLAGS.rnn_resnet:
            # Adds the cell's input to its output, so the input width of each
            # layer has to equal FLAGS.rnn_output_num.
            cell = tf.contrib.rnn.ResidualWrapper(cell)
        cells.append(cell)

    rnn_cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)
    if dropout is not None:
        rnn_cell = tf.contrib.rnn.DropoutWrapper(
            rnn_cell,
            input_keep_prob=dropout,
            output_keep_prob=dropout,
            state_keep_prob=dropout)
    return rnn_cell


def _build_variable_initial_state(rnn_cell, batch_size, initializer):
    """Return a trainable initial state for `rnn_cell`, tiled over the batch.

    For every layer i of the stacked cell, two variables of shape (1, width)
    are created: 'vc_i' (LSTM cell state) and 'vh_i' (LSTM hidden state).
    Each is tiled `batch_size` times along axis 0.
    """
    zero_state = rnn_cell.zero_state(batch_size, tf.float32)
    head_state = []
    for index, layer_state in enumerate(zero_state):
        vc = tf.get_variable(
            'vc_{}'.format(index),
            (1, layer_state.c.shape[1]),
            initializer=initializer)
        vh = tf.get_variable(
            'vh_{}'.format(index),
            (1, layer_state.h.shape[1]),
            initializer=initializer)
        head_state.append(tf.contrib.rnn.LSTMStateTuple(
            tf.tile(vc, [batch_size, 1]),
            tf.tile(vh, [batch_size, 1])))
    return head_state


def build_model(learning_rate=0.0007, beta1=0.5):
    """Build the LSTM sequence model, its loss and its Adam training op.

    Graph inputs (created here as placeholders):
      source: float32, shape [batch, 32, 28 * len(FLAGS.features)].
      target: float32, shape [batch, 32, 28].
      dropout: float32 scalar, only when FLAGS.use_dropout. Despite the name
        it is fed to DropoutWrapper as a *keep* probability, so feeding 1.0
        disables dropout.

    Pipeline: optional ReLU "head" FC layers per step -> FLAGS.rnn_num stacked
    peephole LSTMs unrolled over the 32 steps -> optional ReLU "tail" FC
    layers per step -> a linear projection to 28 logits per step.

    The loss is the mean sigmoid cross-entropy between those logits and
    `target`. It covers all 32 steps when FLAGS.use_sequence_loss is set and
    only the last step otherwise.

    Args:
      learning_rate: Adam learning rate (default 0.0007).
      beta1: Adam beta1 (default 0.5).

    Returns:
      A dict with:
      - the placeholders 'source', 'target' and 'dropout' ('dropout' is None
        when dropout is disabled);
      - 'guess': sigmoid probabilities for the last step, shape [batch, 28];
      - the scalar 'loss';
      - the 'trainer' op.
    """
    initializer = tf.truncated_normal_initializer(stddev=0.02)
    feature_size = len(FLAGS.features)

    source = tf.placeholder(
        shape=[None, 32, 28 * feature_size], dtype=tf.float32)
    target = tf.placeholder(shape=[None, 32, 28], dtype=tf.float32)
    dropout = (tf.placeholder(shape=[], dtype=tf.float32)
               if FLAGS.use_dropout else None)

    batch_size = tf.shape(source)[0]

    # Head weighting: per-step FC layers over the raw input features.
    flow = source
    if FLAGS.head_fc_num > 0:
        flow = tf.reshape(flow, [-1, 32, 1, 28 * feature_size])
        flow = _fc_stack(flow, FLAGS.head_fc_num, FLAGS.head_fc_output_num,
                         'head_fc', initializer)
        flow = tf.reshape(flow, [-1, 32, FLAGS.head_fc_output_num])

    # static_rnn consumes a Python list of 32 per-step [batch, width] tensors.
    segments = tf.unstack(flow, 32, axis=1)

    rnn_cell = _build_rnn_cell(initializer, dropout)

    # None makes static_rnn start from the cell's zero state.
    head_state = None
    if FLAGS.use_variable_initial_state:
        head_state = _build_variable_initial_state(
            rnn_cell, batch_size, initializer)

    segments, _ = tf.contrib.rnn.static_rnn(
        rnn_cell, segments, head_state, dtype=tf.float32)
    result = tf.concat(segments, axis=1)  # [batch, 32 * rnn output width]

    # Tail weighting: per-step FC layers over the RNN outputs.
    if FLAGS.tail_fc_num > 0:
        flow = tf.reshape(result, [-1, 32, 1, FLAGS.rnn_output_num])
        flow = _fc_stack(flow, FLAGS.tail_fc_num, FLAGS.tail_fc_output_num,
                         'tail_fc', initializer)
        result = tf.reshape(flow, [-1, 32 * FLAGS.tail_fc_output_num])

    # Per-step linear projection to 28 logits. Integer division keeps the
    # reshape dimension an int on Python 3 (true division gave a float).
    step_width = result.get_shape().as_list()[1] // 32
    result = tf.reshape(result, [-1, 32, 1, step_width])
    # Keep the established 'final_fc_<k>' scope name so variables saved by
    # earlier runs still load. Here k is the last tail-FC layer index, or the
    # last RNN layer index when there are no tail layers. That is the value
    # the leaked loop variable used to carry.
    final_index = (FLAGS.tail_fc_num - 1 if FLAGS.tail_fc_num > 0
                   else FLAGS.rnn_num - 1)
    result = tf.contrib.layers.fully_connected(
        inputs=result,
        num_outputs=28,
        activation_fn=None,
        weights_initializer=initializer,
        scope='final_fc_{}'.format(final_index))

    guess = tf.reshape(result, (-1, 32 * 28))
    truth = tf.reshape(target, (-1, 32 * 28))
    if FLAGS.use_sequence_loss:
        range_truth, range_guess = truth, guess
    else:
        # The last 28 flattened values are the final time step.
        range_truth, range_guess = truth[:, -28:], guess[:, -28:]

    loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=range_truth, logits=range_guess))

    trainer = tf.train.AdamOptimizer(
        learning_rate=learning_rate, beta1=beta1).minimize(loss)

    # Expose probabilities for the last time step only.
    guess = tf.nn.sigmoid(tf.reshape(guess, (-1, 32, 28))[:, -1, :])

    print(guess.shape)

    return {
        'source': source,
        'target': target,
        'guess': guess,
        'loss': loss,
        'trainer': trainer,
        'dropout': dropout,
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment