# Gist by @notcome, created December 9, 2018.
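# What this script does (inferred from the code below): a single-hidden-layer
# network with tied input/output embeddings is trained to reproduce (copy)
# random token IDs. Because every batch is freshly sampled noise, the only
# thing the network can learn is the per-token identity mapping through the
# 128-dimensional bottleneck.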
import torch
import torch.nn as nn
import torch.nn.functional as F

vocabSize = 30000   # number of distinct token IDs
hiddenSize = 128    # embedding / hidden dimension
batchSize = 16
seqSize = 128       # tokens per sequence
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.embedding = nn.Embedding(vocabSize, hiddenSize)
        self.linear = nn.Linear(hiddenSize, hiddenSize)
        self.layerNorm = nn.LayerNorm(hiddenSize)
        self.output = nn.Linear(hiddenSize, vocabSize, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocabSize))
        # Tie the output projection to the embedding matrix.
        self.output.weight = self.embedding.weight
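    # Note on the tying above: nn.Linear(hiddenSize, vocabSize) stores its
    # weight as a (vocabSize, hiddenSize) matrix, the same shape as the
    # nn.Embedding table, so assigning one Parameter to the other shares the
    # storage. The projection was built with bias=False; the separate
    # self.bias parameter is added by hand in forward() instead.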
    def forward(self, input):
        hiddens = self.embedding(input)                # (batch, seq, hidden)
        hiddens = self.linear(hiddens)
        hiddens = F.relu(hiddens)
        hiddens = F.dropout(hiddens, p=0.1, training=self.training)
        hiddens = self.layerNorm(hiddens)
        preds = self.output(hiddens)                   # (batch, seq, vocab)
        return preds + self.bias
    def loss(self, target):
        # Copy task: the input doubles as the target.
        logits = self.forward(target)
        count = batchSize * seqSize
        return F.cross_entropy(logits.view(count, -1), target.view(count))
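# Shape walk-through with the defaults above:
#   input          (16, 128)        int64 token IDs
#   embedding   -> (16, 128, 128)
#   linear / relu / dropout / layerNorm keep (16, 128, 128)
#   tied output -> (16, 128, 30000) logits, flattened to (2048, 30000)
#   cross_entropy against the flattened (2048,) targets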
net = Net()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net.to(device)

def sample():
    # Each call draws a fresh batch of uniformly random token IDs.
    return torch.randint(0, vocabSize, (batchSize, seqSize),
                         dtype=torch.long, device=device)
# Note: weight_decay in torch.optim.Adam is an L2 penalty folded into the
# gradients, not the decoupled decay of AdamW.
optimizer = torch.optim.Adam(net.parameters(), lr=0.0001,
                             betas=(0.9, 0.999), weight_decay=0.01)
def stepBatch():
    optimizer.zero_grad()
    loss = net.loss(sample())
    loss.backward()
    optimizer.step()
    return loss.item()
def evaluate(testSize=16):
    # Measure copy accuracy over testSize freshly sampled batches, with
    # dropout disabled and no autograd graph built.
    net.eval()
    wrong = 0
    total = 0
    with torch.no_grad():
        for i in range(testSize):
            labels = sample()
            output = net(labels)
            preds = torch.argmax(output, dim=-1)
            wrong += (preds != labels).sum().item()
            total += batchSize * seqSize
    net.train()
    acc = float(total - wrong) / float(total) * 100.0
    print('Acc = %.2f%%' % acc)
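# Note on evaluate(): because sample() draws new random batches every call,
# the accuracy reflects the learned per-token identity mapping itself, not
# memorization of any particular training batch.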
def train():
    lossSum = 0.0
    for i in range(1, 100001):
        lossSum += stepBatch()
        if i % 100 == 0:
            # Report the running average loss and accuracy every 100 steps.
            avg = lossSum / 100.0
            lossSum = 0.0
            print('[%6d] avg loss = %.4f' % (i, avg))
            evaluate()

train()