# DotProdNB.py
# Forked from bkj/DotProdNB.py; gist by @dainis-boumber (April 22, 2019)
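#
# Trains an NBSVM-style "dot-product naive Bayes" sentiment classifier on the
# IMDB reviews in data/aclImdb: each feature gets a fixed naive Bayes
# log-count ratio r and a learned weight w, optimized with Adam. Written
# against the old fastai library (fastai.nlp, fastai.dataloader) and the
# pre-0.4 PyTorch API (Variable, volatile, loss.data[0]).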
import os
import re
import string
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F
from torch.utils.data.sampler import SequentialSampler, RandomSampler, BatchSampler
from fastai.nlp import texts_from_folders
from fastai.dataloader import DataLoader
# --
# Helpers
def to_numpy(x):
    if isinstance(x, Variable):
        return to_numpy(x.data)
    return x.cpu().numpy() if x.is_cuda else x.numpy()

class DotProdNB(nn.Module):
    def __init__(self, vocab_size, n_classes, r, w_adj=0.4, r_adj=10, lr=0.02, weight_decay=1e-6):
        super().__init__()
        
        # Init w: learned per-feature weight (index 0 reserved for padding)
        self.w = nn.Embedding(vocab_size + 1, 1, padding_idx=0)
        self.w.weight.data.uniform_(-0.1, 0.1)
        
        # Init r: fixed naive Bayes log-count ratios, with a zero row
        # prepended for the padding index
        self.r = nn.Embedding(vocab_size + 1, n_classes)
        self.r.weight.data = torch.Tensor(np.concatenate([np.zeros((1, n_classes)), r])).cuda()
        self.r.weight.requires_grad = False
        
        self.w_adj = w_adj
        self.r_adj = r_adj
        
        params = [p for p in self.parameters() if p.requires_grad]
        self.opt = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
    
    def forward(self, feat_idx, feat_cnt, sz):
        # feat_cnt and sz come from the batch but are unused by this model
        w = self.w(feat_idx)
        r = self.r(feat_idx)
        
        x = ((w + self.w_adj) * r / self.r_adj).sum(1)
        return F.softmax(x)
    
    def step(self, x, y):
        x = [Variable(xx).cuda() for xx in x]
        y = Variable(y).cuda()
        
        output = self(*x)
        
        self.opt.zero_grad()
        loss = F.l1_loss(output, y)
        loss.backward()
        self.opt.step()
        
        return loss.data[0]
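
# For a document with feature indices i, the predicted distribution is
# softmax(sum_i (w_i + w_adj) * r_i / r_adj): w_adj guarantees every present
# feature contributes at least w_adj * r_i (pulling the model toward plain
# naive Bayes), while r_adj scales the log-count ratios down.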

def one_hot(a, c):
    return np.eye(c)[a]

def calc_r(y_i, x, y):
    x = x.sign()
    p = x[np.argwhere(y == y_i)[:,0]].sum(axis=0) + 1
    q = x[np.argwhere(y != y_i)[:,0]].sum(axis=0) + 1
    return np.log((p / p.sum()) / (q / q.sum()))
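# This is the binarized, Laplace-smoothed log-count ratio from NBSVM
# (Wang & Manning, 2012): p and q count documents containing each feature in
# class y_i and in the remaining classes, and r = log((p/|p|) / (q/|q|)).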

class BOW_Dataset():
    def __init__(self, X, y, max_len):
        self.max_len = max_len
        self.n_classes = int(y.max()) + 1
        self.vocab_size = X.shape[1]
        
        self.X = X
        self.y = one_hot(y, self.n_classes)
    
    def do_pad(self, prepend, a):
        return np.array((prepend + a.tolist())[-self.max_len:])
    
    def pad_row(self, row):
        prepend = [0] * max(self.max_len - len(row.indices), 0)
        # Shift feature indices by 1 so that 0 is free to act as padding
        return self.do_pad(prepend, row.indices + 1), self.do_pad(prepend, row.data)
    
    def __getitem__(self, i):
        row = self.X.getrow(i)
        ind, data = self.pad_row(row)
        return ind, data, len(row.indices), self.y[i].astype(np.float32)
    
    def __len__(self):
        return len(self.X.indptr) - 1
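
# Each item is (padded feature indices, padded feature counts, number of
# nonzero features, one-hot label), matching DotProdNB.forward's
# (feat_idx, feat_cnt, sz) arguments plus the training target.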
# --
# IO
names = ['neg','pos']
text_train, y_train = texts_from_folders('data/aclImdb/train', names)
text_val, y_val = texts_from_folders('data/aclImdb/test', names)
# --
# Preprocess
max_features = 200000
max_len = 1000
re_tok = re.compile(f'([{string.punctuation}“”¨«»®´·º½¾¿¡§£₤‘’])')
tokenizer = lambda x: re_tok.sub(r' \1 ', x).split()
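# re_tok wraps each punctuation character in spaces, so the plain split()
# afterwards yields punctuation marks as standalone tokens.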
vectorizer = CountVectorizer(
    ngram_range=(1, 3),
    tokenizer=tokenizer,
    max_features=max_features
)
X_train = vectorizer.fit_transform(text_train)
X_val = vectorizer.transform(text_val)
vocab_size = X_train.shape[1]
n_classes = int(y_train.max()) + 1
dl_train = DataLoader(BOW_Dataset(X_train, y_train, max_len=max_len), batch_size=64, shuffle=True)
dl_val = DataLoader(BOW_Dataset(X_val, y_val, max_len=max_len), batch_size=64, shuffle=False)
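# Note: this is fastai's DataLoader, not torch's; it batches the numpy arrays
# yielded by BOW_Dataset, so each batch unpacks as (feat_idx, feat_cnt, sz, y)
# and the `*x, y` pattern below collects the first three into a list.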
# --
# Define model
r = np.stack([calc_r(i, X_train, y_train).A1 for i in range(n_classes)]).T
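# r has one column of log-count ratios per class, shape (vocab_size, n_classes);
# .A1 flattens the 1 x vocab_size np.matrix that the sparse sums produce.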
model = DotProdNB(vocab_size, n_classes, r, lr=0.01).cuda()
# --
# Train
_ = model.train()
for _ in range(2):
    for (*x, y) in dl_train:
        _ = model.step(x, y)
# --
# Eval
_ = model.eval()
pred, act = [], []
for *x, y in dl_val:
    pred.append(model(*[Variable(xx, volatile=True).cuda() for xx in x]))
    act.append(y)
pred = to_numpy(torch.cat(pred)).argmax(axis=1)
act = to_numpy(torch.cat(act)).argmax(axis=1)
print('accuracy:', (pred == act).mean())
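
# --
# A minimal sketch of scoring one unseen review, commented out since it is
# illustrative only. The input string is hypothetical; it reuses the fitted
# `vectorizer`, the trained `model`, and the old Variable API from above:
#
# row = vectorizer.transform(['a surprisingly good movie'])
# ind, cnt, sz, _ = BOW_Dataset(row, np.array([0]), max_len=max_len)[0]
# x = [
#     Variable(torch.LongTensor(ind.reshape(1, -1)), volatile=True).cuda(),
#     Variable(torch.Tensor(cnt.reshape(1, -1).astype(np.float32)), volatile=True).cuda(),
#     Variable(torch.LongTensor([sz]), volatile=True).cuda(),
# ]
# names[int(to_numpy(model(*x)).argmax(axis=1)[0])]  # -> 'neg' or 'pos'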