# Straight-through Gumbel-Softmax sampling of a sequence of discrete messages,
# with every timestep at or after the first "stop" symbol (the last class,
# index output_size - 1) zeroed out.
#
# Assumed to be defined elsewhere:
#   a1_logits:           [batch, num_binary_messages, output_size] logits
#   tau:                 Gumbel-Softmax temperature
#   output_size:         vocabulary size (binary symbols plus a stop symbol)
#   num_binary_messages: maximum message length
import numpy as np
import tensorflow as tf

q_y = tf.contrib.distributions.RelaxedOneHotCategorical(tau, logits=a1_logits)
y = q_y.sample()
y_hard = tf.cast(tf.one_hot(tf.argmax(y, -1), output_size), y.dtype)
# Append a guaranteed stop symbol onto the end of every sequence so that the
# argmax below always finds one and never returns an incorrect index.
one_hot = np.array([0]*(output_size - 1) + [1]).astype(np.float32)
concat_one_hot = tf.expand_dims(tf.expand_dims(tf.convert_to_tensor(one_hot), 0), 0)
concat_one_hot = tf.tile(concat_one_hot, tf.stack([tf.shape(y_hard)[0], 1, 1]))
concat_y_hard = tf.concat([y_hard, concat_one_hot], 1)
# Find the first timestep that predicts the stop symbol (a 2) and zero out
# everything from that timestep onward.
first_zeros = tf.argmax(tf.to_int32(tf.reduce_all(tf.equal(concat_y_hard, one_hot), 2)), 1)
mask = tf.to_float(tf.sequence_mask(first_zeros, num_binary_messages + 1))
# Shift class indices up by one so the mask can send trailing timesteps to 0,
# then shift back down: masked timesteps become -1, which tf.one_hot maps to
# an all-zero vector, while real classes map back to their original indices.
argmax_messages = ((tf.argmax(concat_y_hard, 2) + 1) * tf.to_int64(mask))[:, :num_binary_messages]
y_hard = tf.one_hot(argmax_messages - 1, output_size)
# Straight-through estimator: the forward pass uses the hard one-hot messages
# while the backward pass differentiates through the relaxed sample y.
messages = tf.stop_gradient(y_hard - y) + y
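
# --- Usage sketch (not part of the original gist) ---
# A minimal, hypothetical way to exercise the graph above, assuming binary
# symbols {0, 1} plus a stop symbol 2 (so output_size = 3), a maximum length
# of num_binary_messages = 5, tau = 0.5, and random logits standing in for
# a1_logits, e.g. a1_logits = tf.random_normal([4, num_binary_messages, output_size]).
with tf.Session() as sess:
    sampled = sess.run(messages)
    # sampled has shape [4, num_binary_messages, output_size]; every timestep
    # at or after the first sampled stop symbol is an all-zero row, while
    # gradients still flow through the relaxed sample y.
    print(sampled.shape)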