
@cinjon
Last active October 28, 2019 23:24
...
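batch_size = inputs.shape[0]
# Embed each input digit, then flatten the per-digit embeddings into one vector per example.
input_embs = sender_embedding(inputs)
inputs = input_embs.view(batch_size, num_digits * embedding_size_sender)
# Zero-initialised LSTM hidden and cell states, created on the same device as the embeddings.
hx = torch.zeros(batch_size, num_lstm_sender, device=input_embs.device)
cx = torch.zeros(batch_size, num_lstm_sender, device=input_embs.device)
for num in range(num_binary_messages):
    # The same flattened input is fed to the LSTM cell at every message step.
    hx, cx = sender_cell(inputs, (hx, cx))
    output = sender_project(hx)
    pre_logits = sender_out(output)
    # Draw a relaxed sample for this message, using this step's temperature.
    sample = gumbel_softmax(pre_logits, temperature[num])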
...
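The snippet calls `gumbel_softmax(pre_logits, temperature[num])` but does not show its definition. A common choice is the Gumbel-softmax (Concrete) relaxation: add Gumbel(0, 1) noise to the logits, divide by a temperature, and take a softmax, so lower temperatures give samples closer to one-hot. The sketch below is dependency-free and assumes that `pre_logits` holds unnormalized scores for one message symbol and that `temperature` is a list of positive per-step values; the schedule and logits in the usage part are hypothetical, and this is not necessarily what the gist's helper does.

```python
import math
import random


def gumbel_softmax(logits, temperature, rng=random):
    """Relaxed one-hot sample (assumed standard Gumbel-softmax).

    y_i = softmax((logits_i + g_i) / temperature), with g_i = -log(-log(u_i))
    and u_i ~ Uniform(0, 1).
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    eps = 1e-10
    noisy = []
    for logit in logits:
        # Clamp u away from 0 and 1 so both logarithms stay finite.
        u = min(max(rng.random(), eps), 1.0 - eps)
        g = -math.log(-math.log(u))
        noisy.append((logit + g) / temperature)
    # Numerically stable softmax: subtract the max before exponentiating.
    m = max(noisy)
    exps = [math.exp(x - m) for x in noisy]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    rng = random.Random(0)
    temperature = [1.0, 0.5, 0.1]  # hypothetical per-message schedule
    pre_logits = [0.3, -0.2]       # hypothetical logits for one binary symbol
    for num, tau in enumerate(temperature):
        sample = gumbel_softmax(pre_logits, tau, rng)
        print(num, [round(p, 3) for p in sample])
```

Indexing the temperature by message step, as the original loop does with `temperature[num]`, lets each message use its own sharpness. In PyTorch, `torch.nn.functional.gumbel_softmax(logits, tau=..., hard=...)` implements the same relaxation on tensors (with `hard=True` returning a straight-through one-hot sample); whether the gist wraps it is not shown.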