from __future__ import print_function  # keep the print call in train() working on both Python 2 and 3

from dynet import *
from depModel import DepModel
import random, time
import numpy as np

init()  # initialize the DyNet backend (old-style API)

class NNModel(DepModel):
    def __init__(self, words, pos_tags, labels, options):
        self.model = Model()
        self.options = options
        self.trainer = AdamTrainer(self.model, options.lr, options.beta1, options.beta2)

        # Vocabularies; index 0 is reserved for unknown items and <null>.
        self.word_vocab = {w: i + 1 for i, w in enumerate(words)}
        self.pos_vocab = {t: i + 1 for i, t in enumerate(pos_tags)}
        self.dep_vocab = {l: i + 1 for i, l in enumerate(labels)}  # zero for <null>
        self.rdep = ['<null>'] + labels

        # Transition inventory: SHIFT plus one RIGHT-ARC and one LEFT-ARC per dependency label.
        self.rlabel = ['SHIFT']
        for d in labels:
            self.rlabel.append('RIGHT-ARC:' + d)
            self.rlabel.append('LEFT-ARC:' + d)
        self.labels = {l: i for i, l in enumerate(self.rlabel)}

        # Feature layout: 20 word slots, 20 POS slots, and 12 dependency-label slots.
        self.n_word_feat = 20
        self.n_label_feat = 12

        self.we = self.model.add_lookup_parameters((len(words) + 1, options.we))     # word embeddings
        self.pe = self.model.add_lookup_parameters((len(pos_tags) + 1, options.pe))  # POS embeddings
        self.le = self.model.add_lookup_parameters((len(labels) + 1, options.le))    # label embeddings

        # Two-layer feed-forward scorer over the concatenated feature embeddings.
        self.H1 = self.model.add_parameters((options.h1,
                                             self.n_word_feat * (options.pe + options.we) + self.n_label_feat * options.le))
        self.H2 = self.model.add_parameters((options.h2, options.h1))
        self.w = self.model.add_parameters((len(self.labels), options.h2))
        self.b = self.model.add_parameters(len(self.labels))
        self.h1b = self.model.add_parameters(options.h1, init=ConstInitializer(0.2))
        self.h2b = self.model.add_parameters(options.h2, init=ConstInitializer(0.2))

    def read_data(self, path):
        data = []
        for line in open(path, 'r'):
            spl = line.strip().split()
            # Gold action id; unseen action strings fall back to a fixed arc index.
            label = self.labels[spl[-1]] if spl[-1] in self.labels else (3 if spl[-1].startswith('LEFT') else 2)
            # Unknown words/tags/labels map to index 0.
            words = [self.word_vocab[w] if w in self.word_vocab else 0 for w in spl[:self.n_word_feat]]
            tags = [self.pos_vocab[t] if t in self.pos_vocab else 0 for t in
                    spl[self.n_word_feat:2 * self.n_word_feat]]
            dep_labels = [self.dep_vocab[l] if l in self.dep_vocab else 0 for l in spl[2 * self.n_word_feat:-1]]
            data.append(words + tags + dep_labels + [label])
        return data
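
    # For illustration (an assumption about the feature file, inferred from the
    # slicing above): with n_word_feat = 20 and n_label_feat = 12, each line is
    #   w_1 ... w_20  p_1 ... p_20  l_1 ... l_12  ACTION
    # where ACTION is e.g. 'SHIFT', 'LEFT-ARC:nsubj', or 'RIGHT-ARC:dobj'.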

    def build_graph(self, batch, is_train):
        # Input dimensionality of the first hidden layer.
        nf = self.n_word_feat * (self.options.we + self.options.pe) + self.n_label_feat * self.options.le
        # Batched lookups: one per feature slot, concatenated along the feature axis.
        wvecs = concatenate([lookup_batch(self.we, [b[i] for b in batch]) for i in range(self.n_word_feat)])
        pvecs = concatenate(
            [lookup_batch(self.pe, [b[i] for b in batch]) for i in range(self.n_word_feat, 2 * self.n_word_feat)])
        lvecs = concatenate([lookup_batch(self.le, [b[i] for b in batch]) for i in
                             range(2 * self.n_word_feat, (2 * self.n_word_feat) + self.n_label_feat)])
        # Fold the batch into matrix columns so the MLP runs over all examples at once.
        lu = reshape(concatenate([wvecs, pvecs, lvecs]), (nf, len(batch)))
        h1 = rectify(affine_transform([self.h1b.expr(), self.H1.expr(), lu]))
        h2 = rectify(affine_transform([self.h2b.expr(), self.H2.expr(), h1]))
        # Unfold back into a batched score vector for pickneglogsoftmax_batch.
        return reshape(affine_transform([self.b.expr(), self.w.expr(), h2]), (len(self.labels),), len(batch))

    def score(self, str_features):
        # Score a single parser configuration (used at decoding time); unknown
        # strings map to index 0, mirroring read_data.
        wids = [self.word_vocab[w] if w in self.word_vocab else 0 for w in str_features[:self.n_word_feat]]
        pids = [self.pos_vocab[t] if t in self.pos_vocab else 0 for t in
                str_features[self.n_word_feat:2 * self.n_word_feat]]
        lids = [self.dep_vocab[l] if l in self.dep_vocab else 0 for l in
                str_features[2 * self.n_word_feat: (2 * self.n_word_feat) + self.n_label_feat]]
        we = concatenate([self.we[w] for w in wids])
        pe = concatenate([self.pe[t] for t in pids])
        le = concatenate([self.le[l] for l in lids])
        h1 = rectify(affine_transform([self.h1b.expr(), self.H1.expr(), concatenate([we, pe, le])]))
        h2 = rectify(affine_transform([self.h2b.expr(), self.H2.expr(), h1]))
        return (self.w.expr() * h2 + self.b.expr()).value()
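
    # Decoding sketch (not in the original gist): a caller can recover the best
    # transition from these scores through the reverse action list, e.g.
    #   scores = model.score(str_features)
    #   best_action = model.rlabel[int(np.argmax(scores))]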

    def train(self, train_data, step):
        random.shuffle(train_data)
        batch, labels = [], []
        m_loss, total, start = 0, 0, time.time()
        progress = 0
        for d in train_data:
            batch.append(d[:-1])  # feature ids
            labels.append(d[-1])  # gold action id
            progress += 1
            if len(batch) >= self.options.batch:
                # Mean negative log-likelihood over the mini-batch (any final
                # partial batch is dropped).
                loss = sum_batches(pickneglogsoftmax_batch(self.build_graph(batch, True), labels)) / len(labels)
                m_loss, total = m_loss + loss.value(), total + 1
                batch, labels = [], []
                loss.backward()
                self.trainer.update()
                renew_cg()  # fresh computation graph for the next mini-batch
                step += 1
                if total % 10 == 0:
                    adv = str(round(100 * float(progress) / len(train_data), 2)) + '%'
                    print('lr', self.trainer.learning_rate, 'loss', m_loss / total,
                          'time', time.time() - start, 'progress', adv)
                    m_loss, total, start = 0, 0, time.time()
                yield step

    def save(self, filename):
        self.model.save(filename)

    def load(self, filename):
        self.model.populate(filename)
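

# Minimal driver sketch, not part of the original gist: the option values, toy
# vocabularies, and file names below are illustrative assumptions based on how
# NNModel uses them above.
if __name__ == '__main__':
    class Options(object):
        lr, beta1, beta2 = 0.001, 0.9, 0.999  # Adam hyper-parameters
        we, pe, le = 64, 32, 32               # word/POS/label embedding sizes
        h1, h2 = 200, 200                     # hidden-layer sizes
        batch = 1000                          # mini-batch size

    # Toy vocabularies; in practice these come from the training treebank.
    words, pos_tags, labels = ['the', 'cat', 'sat'], ['DT', 'NN', 'VBD'], ['nsubj', 'dobj', 'det']

    model = NNModel(words, pos_tags, labels, Options())
    train_data = model.read_data('train.data')  # hypothetical feature file
    step = 0
    for epoch in range(7):
        for step in model.train(train_data, step):
            pass
        model.save('model.epoch' + str(epoch))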