Skip to content

Instantly share code, notes, and snippets.

Created July 28, 2020 08:03
Show Gist options
  • Save prateekjoshi565/1f335f51c53f0db94b22fa4c61acef14 to your computer and use it in GitHub Desktop.
Save prateekjoshi565/1f335f51c53f0db94b22fa4c61acef14 to your computer and use it in GitHub Desktop.
# predict next token
def predict(net, tkn, h=None):
    """Predict the token that follows ``tkn`` and return it with the new state.

    The next token is drawn uniformly at random from the (up to) three most
    probable candidates, so repeated calls with the same input may differ.

    Args:
        net: Model invoked as ``net(inputs, h)`` and returning ``(out, h)``.
            Inputs are moved to CUDA here, so the model is expected to live
            on the GPU as well. It must provide ``init_hidden(batch_size)``
            when ``h`` is omitted.
        tkn: Current token; must be a key of the module-level ``token2int``.
        h: Hidden state from the previous step, as an iterable of tensors.
            When ``None``, a fresh state for a batch of one is requested
            from ``net.init_hidden(1)``.

    Returns:
        tuple: ``(next_token, h)`` where ``next_token`` is looked up in the
        module-level ``int2token`` and ``h`` is the hidden state returned by
        the model.

    Raises:
        KeyError: If ``tkn`` is not present in ``token2int``.
    """
    # tensor inputs: a 1x1 batch holding the integer id of the token
    x = np.array([[token2int[tkn]]])
    inputs = torch.from_numpy(x)

    # push to GPU
    inputs = inputs.cuda()

    # start from a fresh state when the caller did not supply one
    if h is None:
        h = net.init_hidden(1)

    # detach hidden state from history so no graph is kept across steps
    h = tuple(each.detach() for each in h)

    # get the output of the model
    out, h = net(inputs, h)

    # get the token probabilities as a flat NumPy vector
    # (the reshape assumes ``out`` is (1, vocab_size) -- TODO confirm)
    p = F.softmax(out, dim=1).detach().cpu().numpy()
    p = p.reshape(p.shape[1],)

    # indices of the top 3 values, most probable first; the slice yields
    # fewer than three entries when the vocabulary is smaller than that
    top_n_idx = p.argsort()[-3:][::-1]

    # randomly select one of the candidates (plain ints for dict lookup)
    sampled_token_index = random.choice(top_n_idx.tolist())

    # return the decoded token and the hidden state
    return int2token[sampled_token_index], h
# function to generate text
def sample(net, size, prime='it is'):
    """Generate text by extending ``prime`` with ``size`` predicted tokens.

    Side effects: ``net`` is moved to the GPU and switched to evaluation
    mode, and it is left in that state when the function returns.

    Args:
        net: Model passed to ``predict``; must provide ``init_hidden``,
            ``cuda`` and ``eval``.
        size: Number of tokens to generate after the prime text. Values
            less than one generate nothing.
        prime: Whitespace-separated seed text; it must contain at least one
            token, and every token must be known to ``predict``.

    Returns:
        str: The prime tokens followed by the generated tokens, joined by
        single spaces.

    Raises:
        ValueError: If ``prime`` contains no tokens.
    """
    toks = prime.split()
    if not toks:
        raise ValueError("prime must contain at least one token")

    # push to GPU (predict sends its inputs to CUDA, so the model must match)
    net.cuda()

    # inference only: evaluation mode turns off training-only layers
    net.eval()

    # batch size is 1
    h = net.init_hidden(1)

    # warm up the hidden state on every prime token except the last one;
    # the last one is fed by the first iteration of the generation loop
    for t in toks[:-1]:
        _, h = predict(net, t, h)

    # predict subsequent tokens, feeding each prediction back as next input
    for _ in range(size):
        token, h = predict(net, toks[-1], h)
        toks.append(token)

    return ' '.join(toks)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment