"""
Created on Sun Jun 11 15:22:27 2017
@author: Michael Klachko, based on https://github.com/spro/practical-pytorch/blob/master/char-rnn-generation/char-rnn-generation.ipynb
Changes:
- added batch support
- added multi-GPU support
- minor changes to train code
- removed Unicode support (assume input.txt is ASCII)
- added comments
"""
import string
import random
import torch
import torch.nn as nn
from torch.autograd import Variable
import time, math
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
def time_since(since):
    s = time.time() - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)
printable = string.printable
#Input text is available here: https://sherlock-holm.es/stories/plain-text/cano.txt
text = open('SH.txt', 'r').read()
pruned_text = ''
for c in text:
    if c in printable and c not in '{}[]&_':
        pruned_text += c
    #else: print c,
text = pruned_text
file_len = len(text)
alphabet = ''.join(sorted(set(text)))
n_chars = len(alphabet)
print "\nTraining RNN on Sherlock Holmes novels.\n"
print "\nFile length: {:d} characters\nUnique characters: {:d}".format(file_len, n_chars)
print "\nUnique characters:", alphabet
def random_chunk():
    #upper bound is file_len - chunk_len - 1 so the slice below always
    #contains chunk_len + 1 characters (input plus shifted target)
    start = random.randint(0, file_len - chunk_len - 1)
    end = start + chunk_len + 1
    return text[start:end]
def chunk_vector(chunk):
    vector = torch.zeros(len(chunk)).long()
    for i, c in enumerate(chunk):
        vector[i] = alphabet.index(c) #encode chunk as indices into alphabet, one number per character
    return Variable(vector.cuda(), requires_grad=False)
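#Illustration (added, with a made-up alphabet): if alphabet were 'abc ', then
#chunk_vector('cab') would return a LongTensor Variable holding [2, 0, 1] --
#the position of each character in the alphabet string.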
def random_training_batch():
    inputs = []
    targets = []
    #construct list of input vectors (chunk_len):
    for b in range(batch_size):
        chunk = random_chunk()
        inp = chunk_vector(chunk[:-1])
        target = chunk_vector(chunk[1:])
        inputs.append(inp)
        targets.append(target)
    #construct batches from lists (chunk_len, batch_size):
    #need .view to handle batch_size=1
    #need .contiguous to allow .view later
    inp = torch.cat(inputs, 0).view(batch_size, chunk_len).t().contiguous()
    target = torch.cat(targets, 0).view(batch_size, chunk_len).t().contiguous()
    return inp, target
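#Shape walkthrough (added for clarity): each of the batch_size chunk vectors has
#chunk_len entries; torch.cat stacks them into (batch_size * chunk_len,), .view
#reshapes to (batch_size, chunk_len), and .t() transposes to (chunk_len, batch_size),
#so iterating over a batch yields one time step (batch_size characters) at a time.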
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers):
        super(RNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.encoder = nn.Embedding(input_size, hidden_size)
        self.GRU = nn.GRU(hidden_size, hidden_size, n_layers, batch_first=True)
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, input, batch_size):
        self.init_hidden(batch_size) #hidden state is re-initialized on every forward call
        input = self.encoder(input)
        output, self.hidden = self.GRU(input, self.hidden)
        output = self.decoder(output) #batch_first, so output is (batch, seq, output_size)
        return output

    def init_hidden(self, batch_size):
        self.hidden = Variable(torch.randn(self.n_layers, batch_size, self.hidden_size).cuda())
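#Shape sketch (added for clarity; assumes the Linear decoder acts on the last
#dimension): input indices arrive as (batch, seq), the Embedding maps them to
#(batch, seq, hidden_size), the GRU (batch_first=True) returns (batch, seq, hidden_size),
#and the decoder projects to (batch, seq, output_size), i.e. one logit per alphabet
#character at every time step.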
seq_len = 1 #each character is encoded as a single integer
chunk_len = 128 #number of characters in a single text sample
batch_size = 64 #number of text samples in a batch
n_batches = 200 #size of training dataset (total number of batches)
hidden_size = 256 #width of model
n_layers = 2 #depth of model
LR = 0.005 #learning rate
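#Back-of-the-envelope (added for clarity): one pass over the training set touches
#n_batches * batch_size * chunk_len = 200 * 64 * 128 = 1,638,400 characters,
#all drawn as random, possibly overlapping chunks of the input text.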
net = torch.nn.DataParallel(RNN(n_chars, hidden_size, n_chars, n_layers)).cuda()
optim = torch.optim.Adam(net.parameters(), LR)
cost = nn.CrossEntropyLoss().cuda()
print "\nModel parameters:\n"
print "n_batches: {:d}\nbatch_size: {:d}\nchunk_len: {:d}\nhidden_size: {:d}\nn_layers: {:d}\nLR: {:.4f}\n".format(n_batches, batch_size, chunk_len, hidden_size, n_layers, LR)
print "\nRandom chunk of text:\n\n", random_chunk(), '\n'
def evaluate(prime_str='A', predict_len=100, temp=0.8, batch_size=1):
    prime_input = chunk_vector(prime_str)
    predicted = prime_str
    #feed the priming string through the network, one character at a time:
    for i in range(len(prime_str)-1):
        net(prime_input[i], batch_size)
    inp = prime_input[-1]
    for i in range(predict_len):
        output = net(inp, batch_size)
        #scale logits by temperature and exponentiate; multinomial normalizes the weights:
        output_dist = output.data.view(-1).div(temp).exp()
        top_i = torch.multinomial(output_dist, 1)[0]
        predicted_char = alphabet[top_i]
        predicted += predicted_char
        inp = chunk_vector(predicted_char)
    return predicted
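#Temperature illustration (added; toy numbers): with logits [2.0, 1.0], temp=1.0 gives
#sampling weights in the ratio exp(2)/exp(1) ~ 2.7:1, while temp=0.5 gives
#exp(4)/exp(2) ~ 7.4:1 -- lower temperatures make sampling greedier, higher
#temperatures make it more random.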
start = time.time()
training_set = []
for i in range(n_batches):
    training_set.append(random_training_batch())
i = 0
for inp, target in training_set:
    net.zero_grad()
    loss = 0
    #feed the chunk one character (time step) at a time, accumulating the loss:
    for c, t in zip(inp, target):
        c = c.view(batch_size, seq_len)
        output = net(c, batch_size)
        loss += cost(output, t)
    loss.backward()
    optim.step()
    if i % 100 == 0:
        print "\n\nSample output:\n"
        print evaluate('Wh', 100, 0.8), '\n'
        print('[%s (%d / %d) loss: %.4f]' % (time_since(start), i, n_batches, loss.data[0] / chunk_len))
    i += 1
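#Final sample (added for illustration; not in the original gist): once training has
#finished, generate a longer passage with the same evaluate() helper.
print "\nFinal sample:\n"
print evaluate('Wh', 500, 0.8)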