Skip to content

Instantly share code, notes, and snippets.

@davidmcclure
Created February 13, 2018 21:44
Show Gist options
  • Save davidmcclure/ab7359ff6ca44e56afebfd6c1a279ba0 to your computer and use it in GitHub Desktop.
Save davidmcclure/ab7359ff6ca44e56afebfd6c1a279ba0 to your computer and use it in GitHub Desktop.
import torch
import csv
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
from cached_property import cached_property
from tqdm import tqdm
from sklearn import metrics
from cs287.hw1.data import TEXT, train_iter, val_iter, test_iter
class Classifier(nn.Module):
    """Convolutional text classifier on top of pretrained word embeddings.

    Token ids are embedded, passed through one convolution per kernel size
    (each kernel spans the full embedding width), max-pooled over the
    sequence axis, concatenated, run through dropout and projected to
    per-class log-probabilities.
    """

    def __init__(self, vectors=None, kernel_sizes=(3, 4, 5), num_filters=100,
                 num_classes=2):
        """Build the layers and load the pretrained embedding weights.

        Args:
            vectors: 2-D tensor of shape (vocab_size, embed_dim) used to
                initialise the embedding table. Defaults to
                ``TEXT.vocab.vectors``.
            kernel_sizes: Heights (in tokens) of the convolution kernels.
            num_filters: Number of output channels per kernel size.
            num_classes: Size of the output layer.

        Raises:
            ValueError: If ``kernel_sizes`` is empty.
        """
        super().__init__()

        if vectors is None:
            vectors = TEXT.vocab.vectors

        kernel_sizes = tuple(kernel_sizes)
        if not kernel_sizes:
            raise ValueError('kernel_sizes must contain at least one size')
        self.kernel_sizes = kernel_sizes

        vocab_size, embed_dim = vectors.shape[0], vectors.shape[1]

        self.embeddings = nn.Embedding(vocab_size, embed_dim)
        self.embeddings.weight.data.copy_(vectors)

        # One conv per kernel size; the kernel width equals the embedding
        # size, so each filter slides along the token axis only.
        self.convs1 = nn.ModuleList([
            nn.Conv2d(1, num_filters, (n, embed_dim))
            for n in kernel_sizes
        ])

        self.dropout = nn.Dropout()

        # Pooled features from every kernel size are concatenated.
        self.out = nn.Linear(num_filters * len(kernel_sizes), num_classes)

    def forward(self, x, lengths=None):
        """Return log-probabilities over the classes for each row of ``x``.

        Args:
            x: Integer tensor of token ids. Dimension 1 is treated as the
                token axis that the convolutions slide over.
                NOTE(review): this assumes batch-first input, i.e.
                (batch, seq_len) -- confirm against the iterator config.
            lengths: Accepted so the model can be called as
                ``clf(*batch.text)``; not used by the computation.

        Returns:
            Tensor of shape (x.size(0), num_classes) holding log-softmax
            scores.
        """
        embeds = self.embeddings(x)

        # A sequence shorter than the widest kernel would make Conv2d raise;
        # zero-pad the token axis up to that width.
        shortfall = max(self.kernel_sizes) - embeds.size(1)
        if shortfall > 0:
            embeds = F.pad(embeds, (0, 0, 0, shortfall))

        # Add the single input channel expected by Conv2d.
        x = embeds.unsqueeze(1)

        # The conv output has width 1 in the last dim (kernel spans the
        # whole embedding), so drop it.
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1]

        # Max over the whole remaining sequence axis.
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]

        x = torch.cat(x, 1)
        x = self.dropout(x)
        x = self.out(x)

        return F.log_softmax(x, dim=1)
class CNNModel:
    """Training / evaluation wrapper around a lazily-built ``Classifier``."""

    @cached_property
    def clf(self):
        """The underlying classifier, constructed on first access."""
        return Classifier()

    def train(self, train_iter, val_iter, epochs=10, lr=1e-4):
        """Train for N epochs.

        After each epoch the mean batch loss is printed, followed by the
        accuracy and F1 on ``val_iter``.

        Args:
            train_iter: Iterable of batches exposing ``.text`` (unpacked as
                the classifier's arguments) and ``.label``.
            val_iter: Iterable of batches used for the per-epoch metrics.
            epochs: Number of passes over ``train_iter``.
            lr: Learning rate for the Adam optimizer.
        """
        self.clf.train(True)

        optimizer = torch.optim.Adam(self.clf.parameters(), lr=lr)
        loss_func = nn.NLLLoss()

        # Guard the mean-loss division against an empty iterator.
        num_batches = max(len(train_iter), 1)

        for epoch in range(epochs):
            print(f'\nEpoch {epoch}')

            epoch_loss = 0
            for batch in tqdm(train_iter):
                optimizer.zero_grad()

                y_pred = self.clf(*batch.text)

                # Labels are shifted down by one to serve as class indices.
                loss = loss_func(y_pred, batch.label - 1)
                loss.backward()
                optimizer.step()

                # .item() extracts the Python float; indexing a 0-dim loss
                # tensor with [0] raises on current PyTorch.
                epoch_loss += loss.item()

            print('Loss: %f' % (epoch_loss / num_batches))

            # train=True puts the classifier back in training mode afterwards.
            self.print_metrics(val_iter, train=True)

    def predict(self, x_iter, train=False):
        """Predict test cases.

        Args:
            x_iter: Iterable of batches exposing ``.text`` and ``.label``.
            train: Mode to leave the classifier in once prediction is done
                (``True`` for training mode, ``False`` for eval mode).

        Returns:
            Tuple ``(y_true, y_pred)`` of plain Python lists: the labels read
            from the batches, and the argmax predictions shifted up by one so
            they are on the same scale as the labels.
        """
        self.clf.train(False)

        y_true, y_pred = [], []
        try:
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                for batch in x_iter:
                    y_true += batch.label.data.tolist()

                    preds = self.clf(*batch.text)
                    _, argmax = preds.max(1)

                    y_pred += (argmax.data + 1).tolist()
        finally:
            # Restore the requested mode even if a batch raised.
            self.clf.train(train)

        return y_true, y_pred

    def print_metrics(self, *args, **kwargs):
        """Print accuracy + f1.

        All arguments are forwarded unchanged to ``predict``.
        """
        y_true, y_pred = self.predict(*args, **kwargs)

        print('Accuracy: %f' % metrics.accuracy_score(y_true, y_pred))
        print('F1: %f' % metrics.f1_score(y_true, y_pred))

    def write_kaggle_submission(self, test_iter, path):
        """Write predictions for Kaggle.

        Produces a CSV with header ``Id,Cat`` where ``Id`` is the zero-based
        position of the prediction and ``Cat`` is the predicted label.

        Args:
            test_iter: Iterable of batches passed to ``predict``.
            path: Destination file path; overwritten if it exists.
        """
        _, y_pred = self.predict(test_iter)

        # newline='' lets the csv module control line endings itself.
        with open(path, 'w', newline='') as fh:
            writer = csv.DictWriter(fh, fieldnames=('Id', 'Cat'))
            writer.writeheader()

            for i, cat in enumerate(y_pred):
                writer.writerow(dict(Id=i, Cat=cat))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment