Skip to content

Instantly share code, notes, and snippets.

@dhpollack
Created September 15, 2017 13:18
Show Gist options
  • Save dhpollack/37077cc109fb1af15cb41dd617c47c30 to your computer and use it in GitHub Desktop.
Save dhpollack/37077cc109fb1af15cb41dd617c47c30 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torch.autograd import Variable
from tqdm import tnrange, tqdm_notebook, tqdm
"""My attempt at Karpathy's char-rnn, from his post 'The Unreasonable Effectiveness of Recurrent Neural Networks'.
Currently the loss goes down, but the generated text is still gibberish.
"""
class SimpleGRU(nn.Module):
    """Character-level GRU that predicts one next character per input window.

    The window of ``seq_len`` character indices is embedded, batch-normalized,
    run through a (possibly stacked) GRU, flattened, and projected to
    ``vocab_size`` log-probabilities.

    Args:
        vocab_size: number of distinct character indices.
        emb_size: embedding dimension.
        hid_size: GRU hidden size.
        batch_size: nominal batch size (stored for reference; ``forward`` uses
            the actual batch dimension of its input).
        seq_len: length of every input window; fixes the input size of ``fc1``.
        n_layers: number of stacked GRU layers.
    """

    def __init__(self, vocab_size, emb_size, hid_size, batch_size, seq_len, n_layers=1):
        super(SimpleGRU, self).__init__()
        self.vocab_size = vocab_size
        self.emb_size = emb_size
        self.hid_size = hid_size
        self.n_layers = n_layers
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.emb = nn.Embedding(vocab_size, emb_size)
        # Honour n_layers so the (n_layers, batch, hid_size) hidden state built
        # by callers matches what the GRU expects.
        self.gru = nn.GRU(emb_size, hid_size, num_layers=n_layers, batch_first=True)
        # Every GRU output step is concatenated before the projection.
        self.fc1 = nn.Linear(seq_len * hid_size, vocab_size)
        self.relu = nn.ReLU()  # currently unused
        self.selu = nn.SELU()
        self.logsoftmax = nn.LogSoftmax(dim=1)
        # BatchNorm1d treats dim 1 of a (batch, seq_len, emb_size) tensor as the
        # feature axis, i.e. one channel per time step, so it needs seq_len
        # features. (emb_size - 1 only worked when it happened to equal seq_len.)
        self.batchnorm = nn.BatchNorm1d(seq_len)

    def forward(self, input, hidden):
        """Run one batch of windows through the network.

        Args:
            input: LongTensor of character indices, shape (batch, seq_len).
            hidden: initial GRU state, shape (n_layers, batch, hid_size).

        Returns:
            (log_probs, hidden): log-probabilities of shape (batch, vocab_size)
            and the final GRU hidden state.
        """
        # Debug bookkeeping: shapes seen at each stage of the last call.
        self.sizes = [(input.size(), hidden.size())]
        x = self.emb(input)
        x = self.batchnorm(x)
        self.sizes.append(x.size())
        x, hidden = self.gru(x, hidden)
        # Flatten all time steps; use the real batch size, not the nominal one.
        x = x.contiguous().view(x.size(0), -1)
        # NOTE(review): SELU on the logits bounds them below at about -1.76,
        # which limits how peaked the softmax can get; kept so existing
        # checkpoints behave the same.
        x = self.selu(self.fc1(x))
        self.sizes.append((x.size(), hidden.size()))
        x = self.logsoftmax(x)
        self.sizes.append(x.size())
        return x, hidden
class CharDataset(data.Dataset):
    """Sliding-window next-character dataset over a 1-D tensor of indices.

    Item ``index`` is ``(inp_seq, tgt_seq)``. ``inp_seq`` holds the
    ``seq_len - 1`` elements starting at ``index``. ``tgt_seq`` is a
    one-element tensor holding the element that immediately follows them.
    Each item therefore consumes ``seq_len`` consecutive elements.
    """

    def __init__(self, data, seq_len):
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        end = index + self.seq_len - 1
        inp_seq = self.data[index:end]
        # Target is the element right after the input window (indexing at
        # index + seq_len used to skip one character). Slicing keeps the
        # source dtype without a round-trip through a float tensor.
        tgt_seq = self.data[end:end + 1]
        return inp_seq, tgt_seq

    def __len__(self):
        # Number of full windows of seq_len elements; never negative.
        return max(0, len(self.data) - self.seq_len + 1)
# --- Hyper-parameters -------------------------------------------------------
seq_length = 25  # window length handed to CharDataset; the model sees seq_length - 1 inputs
batch_size = 250
emb_size = 25
hid_size = 100
n_layers = 3

data_path = "/home/david/Programming/data/project_gutenberg/tiny-shakespeare.txt"
checkpoint_path = "checkpoints/char-rnn/model_19.pt"

# --- Data -------------------------------------------------------------------
with open(data_path, "r", encoding="utf-8") as f:
    text_raw = list(f.read())  # one entry per character, newlines included
charset = sorted(set(text_raw))
c2i = {c: i for i, c in enumerate(charset)}
i2c = dict(enumerate(charset))
text_idx = [c2i[c] for c in text_raw]
tqdm.write("{} {}".format(len(text_idx), len(text_raw)))
inputs = torch.tensor(text_idx, dtype=torch.long)
tqdm.write("{}".format(inputs.size()))
ds = CharDataset(inputs, seq_length)
dl = data.DataLoader(ds, batch_size=batch_size, drop_last=True)
tqdm.write("{}".format(len(dl)))
vocab_size = len(charset)
num_batches = len(dl)

# --- Model / optimisation ---------------------------------------------------
epochs = 20
lr = 0.003
momo = 0.9
model = SimpleGRU(vocab_size, emb_size, hid_size, batch_size, seq_length - 1, n_layers)
# The model never leaves the CPU, so map any GPU-saved tensors onto it.
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momo)
tqdm.write("{}".format(model))

train = False
if train:
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        batch_bar = tqdm(enumerate(dl), desc="batches", total=num_batches)
        for i, (mb, tgts) in batch_bar:
            # Windows in different batches are not contiguous, so each batch
            # starts from a fresh zero hidden state.
            h = torch.zeros(n_layers, batch_size, hid_size)
            # (batch, 1) -> (batch,); squeeze only dim 1 so a batch of one
            # does not collapse to a 0-dim tensor.
            tgts = tgts.squeeze(1)
            model.zero_grad()
            out, h = model(mb, h)
            loss = criterion(out, tgts)
            loss.backward()
            optimizer.step()
            last_loss = loss.item()
            running_loss += last_loss
            if (i % 25 == 0 and i > 0) or i == num_batches - 1:
                batch_bar.set_postfix(ave_loss=running_loss / (i + 1), last_loss=last_loss)
        torch.save(model.state_dict(), "model_{}.pt".format(epoch))
        tqdm.write("epoch {}".format(epoch + 1))
else:
    model.eval()
    pred = []
    batch_bar = tqdm(enumerate(dl), desc="batches", total=num_batches)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for i, (mb, tgts) in batch_bar:
            h = torch.zeros(n_layers, batch_size, hid_size)
            out, h = model(mb, h)
            pred.append(out.max(1)[1])
            if i == 100:
                break
    pred = torch.cat(pred)
    # .tolist() gives Python ints; iterating the tensor would yield 0-dim
    # tensors, which do not hash equal to i2c's int keys.
    pred_c = [i2c[i] for i in pred.tolist()]
    print("".join(pred_c[2000:2350]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment