Skip to content

Instantly share code, notes, and snippets.

@albertlai431
Created February 7, 2019 22:18
Show Gist options
  • Save albertlai431/964688a10dd510f7ce6ed182a496caa2 to your computer and use it in GitHub Desktop.
Save albertlai431/964688a10dd510f7ce6ed182a496caa2 to your computer and use it in GitHub Desktop.
# Defining a method to generate the next character
def predict(net, char, h=None, top_k=None):
    """Given a character, predict the next character.

    Args:
        net: Model object exposing ``chars``, ``char2int`` and ``int2char``
            and callable as ``net(inputs, hidden) -> (output, hidden)``.
        char: Current character; must be a key of ``net.char2int``.
        h: Iterable of hidden-state tensors from the previous step, or
            ``None`` to hand ``None`` to the network as the hidden state.
            NOTE(review): assumes ``net`` tolerates a ``None`` hidden state
            (``torch.nn.LSTM``/``GRU`` do) -- confirm against the model.
        top_k: If given, sample only among the ``top_k`` most probable
            characters; otherwise sample over all of ``net.chars``.

    Returns:
        A ``(next_char, hidden)`` tuple: the sampled character and the
        hidden state returned by the network.
    """
    # Encode the single input character as a 1x1 index array, then one-hot it.
    x = np.array([[net.char2int[char]]])
    x = one_hot_encode(x, len(net.chars))
    inputs = torch.from_numpy(x)
    if train_on_gpu:
        inputs = inputs.cuda()

    # Detach the hidden state from its history so no graph is kept across
    # calls. The default h=None is passed through untouched instead of
    # crashing on iteration.
    if h is not None:
        h = tuple(each.detach() for each in h)

    # Run the model for one step.
    out, h = net(inputs, h)

    # Character probabilities, detached from the autograd graph.
    p = F.softmax(out, dim=1).detach()
    if train_on_gpu:
        p = p.cpu()  # NumPy conversion below requires a CPU tensor

    # Candidate character indices: everything, or only the top_k most likely.
    if top_k is None:
        top_ch = np.arange(len(net.chars))
    else:
        p, top_ch = p.topk(top_k)
        # reshape(-1) rather than squeeze(): keeps a 1-D array when top_k == 1.
        top_ch = top_ch.numpy().reshape(-1)

    # Sample the next character with some randomness. Normalise in float64
    # so the probabilities reliably sum to 1 for np.random.choice.
    p = p.numpy().reshape(-1).astype(np.float64)
    char = np.random.choice(top_ch, p=p / p.sum())

    # Return the decoded character and the new hidden state.
    return net.int2char[char], h
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment