@cinjon
Created November 7, 2017 00:21
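# Context (assumed, not part of the original gist): build_graph below is a method on a model
# class in a TensorFlow 1.x codebase; `import tensorflow as tf` and `import numpy as np` are
# expected at module level, and the tf.contrib.layers / tf.contrib.distributions APIs must be available.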
def build_graph(self):
    label = tf.one_hot(self.batch, 10 * self._config.num_digits)
    self.label = tf.argmax(label, -1)
    num_digits = self._config.num_digits
    num_binary_messages = self._config.num_binary_messages
    # Speaker
    with tf.variable_scope("A1"):
        weights = tf.get_variable("embeddings", shape=(10 * num_digits, self._config.embedding_size),
                                  dtype=tf.float32, initializer=tf.orthogonal_initializer)
        inputs = tf.nn.embedding_lookup(weights, self.batch)
        inputs = tf.reshape(inputs, [self._batch_size, num_digits * self._config.embedding_size])
        # Fall back to the shared hidden size when a1_hidden_size is unset.
        hidden_size = getattr(self._config, 'a1_hidden_size', None) or self._config.hidden_size
        hidden = tf.contrib.layers.fully_connected(
            inputs, hidden_size, scope="hidden", activation_fn=tf.nn.tanh,
            # weights_initializer=tf.orthogonal_initializer,
        )
        # If we are using an EOS penalty, then we predict a third attribute per message:
        # whether to zero out the remaining messages. If we predict to do so, every message
        # thereafter will be masked.
        output_size = 3 if self._config.eos_penalty else 2
        logits = tf.contrib.layers.fully_connected(
            hidden, output_size * num_binary_messages,
            activation_fn=None, scope="logits",
            ### This is commented out because, with it included, the last part of the weight
            ### matrix becomes 0. I can't explain that and it bears further inquiry.
            # weights_initializer=tf.orthogonal_initializer,
        )
        if self._config.eos_penalty:
            # We additionally predict a fourth attribute meant to be a scalar that pushes the EOS probability.
            scalar_eos = tf.contrib.layers.fully_connected(
                logits, 1, activation_fn=tf.sigmoid, scope="scalar_eos",
                weights_initializer=tf.orthogonal_initializer)
        a1_logits = tf.reshape(logits, [self._batch_size, num_binary_messages, output_size])
        # Learnable temperature for the Gumbel-Softmax (Concrete) relaxation.
        tau = tf.get_variable("QTemperature", initializer=tf.constant_initializer(1.0),
                              trainable=True, shape=())
        q_y = tf.contrib.distributions.RelaxedOneHotCategorical(tau, logits=a1_logits)
        y = q_y.sample()
        # Discretize the relaxed sample into a hard one-hot (used by the straight-through estimator below).
        y_hard = tf.cast(tf.one_hot(tf.argmax(y, -1), output_size), y.dtype)
        if self._config.eos_penalty:
            # Append an EOS one-hot onto the back so that argmax doesn't use an incorrect index.
            self.pre_mask_argmax_messages = tf.argmax(y, -1)
            one_hot = np.array([0] * (output_size - 1) + [1]).astype(np.float32)
            concat_one_hot = tf.expand_dims(tf.expand_dims(tf.convert_to_tensor(one_hot), 0), 0)
            concat_one_hot = tf.tile(concat_one_hot, tf.stack([tf.shape(y_hard)[0], 1, 1]))
            concat_y_hard = tf.concat([y_hard, concat_one_hot], 1)
            # We need to find the first message that predicts a 2 (the EOS class) and then zero out from there.
            first_zeros = tf.argmax(tf.to_int32(tf.reduce_all(tf.equal(concat_y_hard, one_hot), 2)), 1)
            self.mask = mask = tf.to_float(tf.sequence_mask(first_zeros, num_binary_messages + 1))
            argmax_messages = ((tf.argmax(concat_y_hard, 2) + 1) * tf.to_int64(mask))[:, :num_binary_messages]
            y_hard = tf.one_hot(argmax_messages, output_size)
            self.argmax_messages = argmax_messages
        else:
            self.argmax_messages = tf.argmax(y, -1)
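        # Straight-through estimator: the forward pass uses the hard one-hot y_hard, while
        # gradients flow through the relaxed sample y (stop_gradient drops the gap between
        # them from the backward pass), so the speaker remains trainable despite discrete messages.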
        self.messages = messages = tf.stop_gradient(y_hard - y) + y
with tf.variable_scope("A2"):
print(messages)
projection = tf.contrib.layers.fully_connected(messages, self._config.embedding_size,
activation_fn=None, scope="embeddings")
messages = tf.reshape(projection, [self._batch_size, self._config.embedding_size * num_binary_messages])
hidden_size = getattr(self._config, 'a2_hidden_size') or self._config.hidden_size
hidden = tf.contrib.layers.fully_connected(
messages, hidden_size, scope="hidden", activation_fn=tf.nn.tanh,
# weights_initializer=tf.orthogonal_initializer,
)
        a2_logits = tf.contrib.layers.fully_connected(
            hidden, 10 * num_digits ** 2, activation_fn=None, scope="a2_logits",
            # weights_initializer=tf.orthogonal_initializer,
        )
        a2_logits = tf.reshape(a2_logits, [self._batch_size, num_digits, 10 * num_digits])
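        # a2_logits is now [batch, num_digits, 10 * num_digits]: one set of logits over the
        # 10 * num_digits label classes for each digit position, matching `label` built at the top.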
    # TODO: Add the entropy penalty.
    # entropy = -1 * tf.reduce_sum(softmax * tf.log(softmax), 2)
    # entropies = [tf.reduce_sum(entropy, 1)]
    # L2 norms of each agent's weights for summaries, excluding the temperature variable
    # (case-insensitive check so that "QTemperature" is actually filtered out).
    a1_vars = [v for v in tf.trainable_variables()
               if 'A1' in v.name and 'temperature' not in v.name.lower()]
    a1_l2_norm = tf.add_n([tf.nn.l2_loss(v) for v in a1_vars])
    a2_vars = [v for v in tf.trainable_variables()
               if 'A2' in v.name and 'temperature' not in v.name.lower()]
    a2_l2_norm = tf.add_n([tf.nn.l2_loss(v) for v in a2_vars])
    weight_summaries = tf.summary.merge([
        tf.summary.scalar('a1_l2_norm', a1_l2_norm),
        tf.summary.scalar('a2_l2_norm', a2_l2_norm),
    ])
    # Per-digit cross-entropy between A2's logits and the one-hot labels built at the top.
    losses = [
        tf.nn.softmax_cross_entropy_with_logits(logits=a2_logits[:, i], labels=label[:, i])
        for i in range(num_digits)
    ]
    losses = [tf.reduce_mean(loss) for loss in losses]
    total_loss = tf.add_n(losses)
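    # NOTE (not part of the original gist): total_loss and weight_summaries are presumably
    # consumed by a training loop elsewhere in the class. A minimal sketch of that step,
    # assuming a hypothetical self._config.learning_rate, might look like:
    #     self.total_loss = total_loss
    #     self.train_op = tf.train.AdamOptimizer(self._config.learning_rate).minimize(total_loss)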