import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import itertools
import time
import argparse

class EmbedDict(object):
    """Pre-trained word vectors loaded from a text file with one 'word v1 v2 ... vD' entry per line."""

    def __init__(self, path):
        self.word_to_idx = {}
        vecs = []
        for line in open(path, 'r'):
            word, *dims = line.split(' ')
            dims = torch.FloatTensor([float(x) for x in dims])
            self.word_to_idx[word] = len(self.word_to_idx)
            vecs.append(dims)
        self.embeddings = torch.stack(vecs)

    def convert(self, sentence):
        # Map words to embedding indices, dropping out-of-vocabulary tokens.
        idxs = [self.word_to_idx.get(w, -1) for w in sentence]
        idxs = [i for i in idxs if i >= 0]
        if len(idxs) == 0:
            return None
        return autograd.Variable(torch.LongTensor(idxs))

class LSTMTagger(nn.Module):
    def __init__(self, embed_dict, hidden_dim, num_classes):
        super(LSTMTagger, self).__init__()
        self.hidden_dim = hidden_dim
        self.embed_dict = embed_dict
        vocab_size, embedding_dim = embed_dict.embeddings.size()
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # freeze the pre-trained embeddings
        self.word_embeddings.weight.data.copy_(embed_dict.embeddings)
        self.word_embeddings.weight.requires_grad = False
        # single-layer bidirectional LSTM
        self.num_dirs = 2
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True)
        # linear layer that maps from hidden-state space to tag space
        self.hidden2tag = nn.Linear(self.num_dirs * hidden_dim, num_classes)
        self.dropout = nn.Dropout(0.5)
        self.hidden = self.init_hidden()

    def init_hidden(self):
        # (h_0, c_0) for a batch of one sentence, one slice per direction
        return (autograd.Variable(torch.zeros(self.num_dirs, 1, self.hidden_dim)),
                autograd.Variable(torch.zeros(self.num_dirs, 1, self.hidden_dim)))

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        # reshape to (seq_len, batch=1, embedding_dim) as nn.LSTM expects
        lstm_input = embeds.view(len(sentence), 1, -1)
        lstm_out, self.hidden = self.lstm(lstm_input, self.hidden)
        # classify from the final time step's output (both directions concatenated)
        hidden_state = lstm_out[-1].view(1, self.num_dirs * self.hidden_dim)
        hidden_state = self.dropout(hidden_state)
        return self.hidden2tag(hidden_state)

def load_data(path):
    # Each line is "<label> <word> <word> ..."; lines that fail to parse are skipped.
    for line in open(path, 'rb'):
        try:
            label, *sent = line.decode('utf-8').split()
            yield sent, int(label)
        except (UnicodeDecodeError, ValueError):
            continue

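# A minimal sketch of the input format load_data expects (the sentences and the
# 0/1 labels below are illustrative; any integer labels in [0, num_classes) work):
#
#   1 this movie was great
#   0 the plot made no sense
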
parser = argparse.ArgumentParser()
parser.add_argument("embed_file", help="path to embedding file")
parser.add_argument("train_file", help="path to train file")
parser.add_argument("test_file", help="path to test file")
parser.add_argument("--num_data", help="number of train/test examples", type=int, default=2000)
parser.add_argument("--batch_size", help="size of batch", type=int, default=32)
parser.add_argument("--lstm_size", help="size of LSTM hidden state", type=int, default=25)
args = parser.parse_args()
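
# Example invocation (a sketch; the script and data file names are placeholders,
# assuming a whitespace-separated embedding text file such as GloVe vectors):
#
#   python lstm_tagger.py glove.6B.50d.txt train.txt test.txt \
#       --num_data 2000 --batch_size 32 --lstm_size 25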

n = args.num_data
embed_dict = EmbedDict(args.embed_file)
train_data = list(itertools.islice(load_data(args.train_file), n))
test_data = list(itertools.islice(load_data(args.test_file), n))
model = LSTMTagger(embed_dict, args.lstm_size, 2)
loss_function = nn.CrossEntropyLoss()
# optimize everything except the frozen embedding weights
learned_params = [param for pname, param in model.named_parameters() if 'word_embeddings' not in pname]
optimizer = optim.Adadelta(learned_params)

def chunks(lst, n):
    # Slicing already yields a short final chunk, so no extra remainder handling is needed.
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def evaluate(model, data):
    model.eval()
    num_correct, num_total = 0, 0
    for sentence, label in data:
        model.hidden = model.init_hidden()
        sentence = embed_dict.convert(sentence)
        if sentence is None:
            continue
        tag_scores = model(sentence)
        predict = tag_scores.data.numpy().argmax()
        if predict == label:
            num_correct += 1
        num_total += 1
    return float(num_correct) / float(num_total)

for epoch in range(100):
    total_loss = torch.Tensor([0])
    start_iter = int(round(time.time() * 1000))
    model.train()
    for batch in chunks(train_data, args.batch_size):
        model.zero_grad()
        for sentence, label in batch:
            model.hidden = model.init_hidden()
            sentence = embed_dict.convert(sentence)
            if sentence is None:
                continue
            tag_scores = model(sentence)
            label_var = autograd.Variable(torch.LongTensor([label]))
            loss = loss_function(tag_scores, label_var)
            total_loss += loss.data
            loss.backward()
        # gradients accumulate over the sentences in the batch; update once per batch
        optimizer.step()
    train_acc = evaluate(model, train_data)
    test_acc = evaluate(model, test_data)
    print('Train Accuracy: ', train_acc)
    print('Test Accuracy: ', test_acc)
    num_millis = int(round(time.time() * 1000)) - start_iter
    print('End of epoch {}: {} [{}ms]'.format(epoch, total_loss[0], num_millis))