Skip to content

Instantly share code, notes, and snippets.


aria42/ Secret

Created Nov 6, 2017
What would you like to do?
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import itertools
import time
import argparse
class EmbedDict(object):
    """Vocabulary and pretrained vectors read from a text embedding file.

    Every non-empty line of the file is parsed as a word followed by its
    vector components, separated by single spaces.

    Attributes:
        word_to_idx: dict mapping a word to its row in ``embeddings``.
        embeddings: FloatTensor holding one row per parsed line.
    """

    def __init__(self, path):
        """Parse the embedding file at ``path``.

        Raises:
            ValueError: if the file contains no embedding lines.
        """
        self.word_to_idx = {}
        vecs = []
        with open(path, 'r', encoding='utf-8') as embed_file:
            for line in embed_file:
                line = line.rstrip('\n')
                if not line:
                    continue
                word, *dims = line.split(' ')
                # Index by the row about to be appended so that indices always
                # line up with the stacked matrix, even if a word repeats.
                self.word_to_idx[word] = len(vecs)
                # `if x` drops the empty field produced by a trailing space.
                vecs.append(torch.FloatTensor([float(x) for x in dims if x]))
        if not vecs:
            raise ValueError('no embeddings found in {}'.format(path))
        self.embeddings = torch.stack(vecs)

    def convert(self, sentence):
        """Turn a sequence of words into a LongTensor Variable of row indices.

        Words that are not in the vocabulary are dropped.  Returns ``None``
        when no word of ``sentence`` is known (including an empty sentence).
        """
        idxs = [self.word_to_idx[w] for w in sentence if w in self.word_to_idx]
        if not idxs:
            return None
        return autograd.Variable(torch.LongTensor(idxs))
class LSTMTagger(nn.Module):
    """Single-layer bidirectional LSTM classifier over frozen word embeddings.

    The model processes one sentence at a time (batch size 1).  Callers are
    expected to reset ``self.hidden`` with ``init_hidden()`` before each
    sentence, because ``forward`` stores the LSTM's final state back into it.
    """

    def __init__(self, embed_dict, hidden_dim, num_classes):
        """Build the network.

        Args:
            embed_dict: object whose ``embeddings`` attribute is a 2-D tensor
                of shape (vocab_size, embedding_dim).
            hidden_dim: size of the LSTM hidden state per direction.
            num_classes: number of output scores.
        """
        super(LSTMTagger, self).__init__()
        self.hidden_dim = hidden_dim
        self.embed_dict = embed_dict
        vocab_size, embedding_dim = embed_dict.embeddings.size()
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # Load the pretrained vectors; without this the frozen table would
        # just hold nn.Embedding's random initialisation.
        self.word_embeddings.weight.data.copy_(embed_dict.embeddings)
        # freeze embeddings
        self.word_embeddings.weight.requires_grad = False
        # single layer bidirectional lstm
        self.num_dirs = 2
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True)
        # The linear layer that maps from hidden state space to tag space
        self.hidden2tag = nn.Linear(self.num_dirs * hidden_dim, num_classes)
        self.dropout = nn.Dropout(0.5)
        self.hidden = self.init_hidden()

    def init_hidden(self):
        """Return a fresh all-zero (h_0, c_0) pair for a batch of one."""
        return (autograd.Variable(torch.zeros(self.num_dirs, 1, self.hidden_dim)),
                autograd.Variable(torch.zeros(self.num_dirs, 1, self.hidden_dim)))

    def forward(self, sentence):
        """Score one sentence.

        Args:
            sentence: 1-D LongTensor of word indices.

        Returns:
            Tensor of shape (1, num_classes) with unnormalised class scores.
        """
        embeds = self.word_embeddings(sentence)
        # nn.LSTM takes (seq_len, batch, features); the batch size is 1 here.
        flatten_input = embeds.view(len(sentence), 1, -1)
        lstm_out, self.hidden = self.lstm(flatten_input, self.hidden)
        # Only the output at the last time step feeds the classifier.
        # NOTE(review): for the reverse direction that position has seen only
        # the final token -- confirm this is the intended sentence summary.
        hidden_state = lstm_out[-1].view(1, self.num_dirs * self.hidden_dim)
        hidden_state = self.dropout(hidden_state)
        return self.hidden2tag(hidden_state)
def load_data(path):
    """Yield ``(tokens, label)`` pairs from a labelled text file.

    Each non-empty line is UTF-8 text of the form ``<int label> <token> ...``
    with single spaces between fields.  The line terminator is removed
    before splitting so the last token does not carry a trailing newline.
    Blank lines are skipped.

    Raises:
        ValueError: if a line's first field is not an integer.
    """
    with open(path, 'rb') as data_file:
        for raw_line in data_file:
            line = raw_line.decode('utf-8').rstrip('\r\n')
            if not line:
                continue
            label, *sent = line.split(' ')
            yield sent, int(label)
# Command line: three required paths followed by optional integer settings.
parser = argparse.ArgumentParser()
parser.add_argument("embed_file", help="path to embedding")
parser.add_argument("train_file", help="path to train file")
parser.add_argument("test_file", help="path to test file")
parser.add_argument(
    "--num_data",
    type=int,
    default=2000,
    help="number of train/test examples",
)
parser.add_argument(
    "--batch_size",
    type=int,
    default=32,
    help="size of batch",
)
parser.add_argument(
    "--lstm_size",
    type=int,
    default=25,
    help="size of LSTM hidden state",
)
args = parser.parse_args()

# Cap on how many examples are read from each of the train and test files.
n = args.num_data
embed_dict = EmbedDict(args.embed_file)
train_data, test_data = (
    list(itertools.islice(load_data(data_path), n))
    for data_path in (args.train_file, args.test_file)
)

# Two-class classifier trained with cross-entropy.
model = LSTMTagger(embed_dict, hidden_dim=args.lstm_size, num_classes=2)
loss_function = nn.CrossEntropyLoss()
# Hand the optimiser every parameter except those of the embedding table.
learned_params = [
    weights
    for label, weights in model.named_parameters()
    if label.find('word_embeddings') < 0
]
optimizer = optim.Adadelta(learned_params)
def chunks(lst, n):
    """Yield consecutive slices of ``lst`` that are ``n`` items long.

    The final slice holds whatever is left over and may be shorter than
    ``n``.  Every item appears in exactly one slice; an empty ``lst`` yields
    nothing.
    """
    # Slicing past the end is safe, so the last iteration already produces
    # the short remainder -- it must not be yielded a second time.
    for i in range(0, len(lst), n):
        yield lst[i:i + n]
def eval(model, data):
    """Return the fraction of examples in ``data`` that ``model`` labels correctly.

    Sentences are converted with the module-level ``embed_dict``; sentences
    with no known word are skipped and do not count towards the total.

    Args:
        model: module exposing ``init_hidden()``, a ``hidden`` attribute and
            a forward pass that returns one row of class scores.
        data: iterable of ``(tokens, label)`` pairs.

    Returns:
        Accuracy as a float in [0, 1]; ``0.0`` when no example could be scored.
    """
    num_correct, num_total = 0, 0
    # Score in evaluation mode so dropout is disabled, then put the model
    # back in whatever mode the caller had it in.
    was_training = model.training
    model.eval()
    try:
        for sentence, label in data:
            model.hidden = model.init_hidden()
            sentence = embed_dict.convert(sentence)
            if sentence is None:
                continue
            tag_scores = model(sentence)
            if tag_scores is None:
                continue
            # Predicted class = index of the highest score in the single row.
            predict = int(tag_scores.data.max(1)[1].view(-1)[0])
            if predict == label:
                num_correct += 1
            num_total += 1
    finally:
        model.train(was_training)
    if num_total == 0:
        return 0.0
    return float(num_correct)/float(num_total)
# Training: 100 passes over train_data, reporting accuracy and summed loss
# after every pass.
for epoch in range(100):
    total_loss = torch.Tensor([0])
    start_iter = int(round(time.time() * 1000))
    # Make sure dropout is active while fitting.
    model.train()
    for batch in chunks(train_data, args.batch_size):
        # NOTE(review): the backward/step calls were missing from this block;
        # gradients are accumulated over the batch and applied once per
        # batch -- confirm this matches the intended update schedule.
        optimizer.zero_grad()
        num_used = 0
        for sentence, label in batch:
            model.hidden = model.init_hidden()
            sentence = embed_dict.convert(sentence)
            if sentence is None:
                # No known word in this sentence; nothing to train on.
                continue
            tag_scores = model(sentence)
            label_var = autograd.Variable(torch.LongTensor([label]))
            loss = loss_function(tag_scores, label_var)
            loss.backward()
            total_loss += loss.data
            num_used += 1
        if num_used:
            optimizer.step()
    train_acc = eval(model, train_data)
    test_acc = eval(model, test_data)
    print('Train Accuracy: ', train_acc)
    print('Test Accuracy: ', test_acc)
    # The elapsed time covers both the training pass and the two evaluations.
    num_millis = int(round(time.time() * 1000)) - start_iter
    print('End of epoch {}: {} [{}ms]'.format(epoch, float(total_loss[0]), num_millis))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment