Skip to content

Instantly share code, notes, and snippets.

@albertlai431
Created February 7, 2019 22:19
Show Gist options
Save albertlai431/c56c72e6e78f7c70718cf18d55cff0b8 to your computer and use it in GitHub Desktop.
def sample(net, size, prime='The', top_k=None):
    """Generate text from a character-level model, seeded with a prime string.

    The model is moved to the GPU or CPU according to the module-level
    ``train_on_gpu`` flag and put into eval mode. Each character of *prime*
    is fed through ``predict`` to build up the hidden state; the character
    predicted after the last prime character is appended. Then *size*
    further characters are generated, each conditioned on the previous one.

    Args:
        net: Model exposing ``cuda()``, ``cpu()``, ``eval()`` and
            ``init_hidden(n)``.
        size: Number of characters to generate after the first
            post-prime prediction.
        prime: Non-empty seed text fed to the model before sampling.
        top_k: Forwarded unchanged to ``predict`` (presumably restricts
            sampling to the k most likely characters -- confirm against
            ``predict``).

    Returns:
        The prime text followed by the generated characters
        (``size + 1`` of them for non-negative *size*).

    Raises:
        ValueError: If *prime* is empty, since there is then no prediction
            to seed the generation loop.
    """
    if not prime:
        # Without at least one prime character the priming loop never runs
        # and `char` would be unbound; fail clearly before touching `net`.
        raise ValueError("prime must contain at least one character")

    if train_on_gpu:
        net.cuda()
    else:
        net.cpu()
    net.eval()  # eval mode

    chars = list(prime)
    h = net.init_hidden(1)  # presumably a batch of one sequence -- confirm

    # Run through the prime characters to build up the hidden state; only the
    # prediction that follows the final prime character is kept.
    for ch in prime:
        char, h = predict(net, ch, h, top_k=top_k)
    chars.append(char)

    # Now pass in the previous character and get a new one.
    for _ in range(size):
        char, h = predict(net, chars[-1], h, top_k=top_k)
        chars.append(char)

    return ''.join(chars)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment