Skip to content

Instantly share code, notes, and snippets.

@yunjey
Last active July 5, 2017 05:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yunjey/5353ba8c770d94e1f170c212a986f63d to your computer and use it in GitHub Desktop.
Save yunjey/5353ba8c770d94e1f170c212a986f63d to your computer and use it in GitHub Desktop.
import torch
import numpy as np
# Demo: turn per-position attention weights into per-word totals over a
# vocabulary. After this script runs, probs[b, v] is the sum of
# attn_weights[b, s] over every position s where word id v occurs in row b.

# Hyper-parameters
vocab_size = 10
batch_size = 3
seq_length = 4

# Random word ids in [0, vocab_size), flattened to length batch_size * seq_length.
# np.random.randint's default integer dtype is platform-dependent (int32 on some
# platforms), so cast to int64 to give the index tensor a consistent dtype.
word_indices = torch.from_numpy(
    np.random.randint(low=0, high=vocab_size, size=(batch_size, seq_length))
).long().view(-1)

# Batch id of every flattened position:
# [0, ..., 0, 1, ..., 1, ..., (batch_size-1), ..., (batch_size-1)]
# (each id repeated seq_length times; total length: seq_length * batch_size).
batch_indices = torch.arange(start=0, end=batch_size).long()
batch_indices = batch_indices.expand(seq_length, batch_size).transpose(1, 0).contiguous().view(-1)

# Random attention weights in [0, 1), one per (batch, position).
attn_weights = torch.rand(batch_size, seq_length)

# Result: per-word totals for each batch row.
probs = torch.zeros(batch_size, vocab_size)
# Position of every flattened element within its own sequence:
# [0, ..., seq_length-1] repeated batch_size times.
idx_repeat = torch.arange(start=0, end=seq_length).repeat(batch_size).long()
# accumulate=True is required: with `probs[i, j] += v`, repeated (i, j) pairs
# (the same word appearing more than once in a row) are NOT summed -- only one
# of the writes survives, so attention mass would be silently dropped.
probs.index_put_(
    (batch_indices, word_indices),
    attn_weights[batch_indices, idx_repeat],
    accumulate=True,
)

# Sample output from one run (no seed is set, so values differ every run).
print("Attention weights: ", attn_weights)
# 0.3650 0.3275 0.3670 0.5113
# 0.9162 0.1386 0.3756 0.0706
# 0.3357 0.9851 0.5998 0.8086
print("Word indices: ", word_indices.view(batch_size, seq_length))
# 1 6 2 8
# 7 3 2 4
# 0 5 1 4
print("Total probabilities: ", probs)
# 0.0000 0.3650 0.3670 0.0000 0.0000 0.0000 0.3275 0.0000 0.5113 0.0000
# 0.0000 0.0000 0.3756 0.1386 0.0706 0.0000 0.0000 0.9162 0.0000 0.0000
# 0.3357 0.5998 0.0000 0.0000 0.8086 0.9851 0.0000 0.0000 0.0000 0.0000
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment