#!/usr/bin/env python3
import _dynet as dy
import numpy as np
from collections import OrderedDict

# End-of-word token.
EOW = "<EOW>"


def sample_softmax(x):
    """Sample an index from a numpy array of probabilities, clamping to the
    last index to guard against floating-point rounding in the cumsum."""
    return min((x.cumsum() < np.random.rand()).sum(), len(x) - 1)


def isDynetParam(p):
    """Return True if p is a DyNet parameter or builder that should be saved."""
    return isinstance(p, dy.Parameters) or \
        isinstance(p, dy.LookupParameters) or \
        isinstance(p, dy._RNNBuilder)
class WordGenerator:
    def __init__(self,
                 vocabulary,
                 embedding_dim,
                 hidden_dim,
                 layers,
                 model,
                 case_sensitive=True):
        if not case_sensitive:
            vocabulary = vocabulary.lower()
        vocabulary = np.unique(list(vocabulary) + [EOW])
        V = len(vocabulary)
        self._model = model
        self._parameters = OrderedDict()
        self.c2i = {c: i for i, c in enumerate(vocabulary)}
        self.i2c = {i: c for i, c in enumerate(vocabulary)}
        self.lookup = model.add_lookup_parameters((V, embedding_dim))
        self.lstm = dy.VanillaLSTMBuilder(layers,
                                          embedding_dim,
                                          hidden_dim,
                                          model)
        self.W = model.add_parameters((V, hidden_dim))
        self.b = model.add_parameters((V,))

    def word_to_indices(self, word):
        word = [EOW] + list(word) + [EOW]
        return [self.c2i[c] for c in word]
    def train_batch(self, words):
        losses = []
        W = dy.parameter(self.W)
        b = dy.parameter(self.b)
        for word in words:
            wlosses = []
            word = self.word_to_indices(word)
            s = self.lstm.initial_state()
            # Per-character cross-entropy: predict each next character
            # from the LSTM state after reading the current one.
            for c, next_c in zip(word, word[1:]):
                s = s.add_input(self.lookup[c])
                unnormalized = dy.affine_transform([b, W, s.output()])
                wlosses.append(dy.pickneglogsoftmax(unnormalized, next_c))
            losses.append(dy.esum(wlosses) / len(word))
        return dy.esum(losses) / len(words)
    def generate(self, num, limit=40, beam=3):
        dy.renew_cg()
        generated = []
        W = dy.parameter(self.W)
        b = dy.parameter(self.b)
        for wordi in range(num):
            # Initialize the LSTM state with the EOW token.
            start_state = self.lstm.initial_state()
            start_state = start_state.add_input(self.lookup[self.c2i[EOW]])
            best_states = [('', start_state, 0)]
            final_hypotheses = []
            # Perform a sampling-based beam search.
            while len(final_hypotheses) < beam and len(best_states) > 0:
                new_states = []
                for hyp, s, p in best_states:
                    # Cut off hypotheses that exceed the character limit.
                    if len(hyp) >= limit:
                        final_hypotheses.append((hyp, p))
                        continue
                    # Get the next-character distribution from the current LSTM state.
                    unnormalized = dy.affine_transform([b, W, s.output()])
                    softmax = dy.softmax(unnormalized).npvalue()
                    # Sample beam candidate continuations.
                    for beami in range(beam):
                        ci = sample_softmax(softmax)
                        c = self.i2c[ci]
                        next_p = softmax[ci]
                        # Accumulate negative log likelihood; lower is better.
                        logp = p - np.log(next_p)
                        if c == EOW:
                            # Finish the hypothesis if we reach end of word.
                            final_hypotheses.append((hyp, logp))
                        else:
                            # Else add to states to search next time step.
                            new_states.append((hyp + c,
                                               s.add_input(self.lookup[ci]),
                                               logp))
                # Sort by score and prune the states to within the beam.
                new_states.sort(key=lambda t: t[-1])
                best_states = new_states[:beam]
            # Keep the hypothesis with the lowest negative log likelihood.
            final_hypotheses.sort(key=lambda t: t[-1])
            generated.append(final_hypotheses[0][0])
        return generated
    def save(self, fname):
        dy.save(fname, [v for k, v in self._parameters.items()])

    def load(self, fname):
        params = dy.load(fname, self._model)
        for name, param in zip(self._parameters, params):
            self.__setattr__(name, param)

    def __setattr__(self, name, value):
        '''When we add an attribute, also record it in the internal parameter
        list if it is a DyNet parameter or builder (see isDynetParam), so that
        save() and load() handle parameters in a consistent order.
        '''
        if isDynetParam(value):
            self._parameters[name] = value
        super().__setattr__(name, value)
def train(network, trainer, words, epochs, batch_size=100):
    words = np.array(words)
    last_loss = None
    for enum in range(epochs):
        # Shuffle the training words each epoch.
        shuf = np.random.permutation(len(words))
        words = words[shuf]
        eloss = 0
        bnum = 0
        for bi in range(0, len(words), batch_size):
            # bi is already an index into words, so slice directly
            # (not words[bi * batch_size:(bi + 1) * batch_size], which
            # would skip every batch after the first).
            bwords = words[bi:bi + batch_size]
            if len(bwords) < 1:
                continue
            dy.renew_cg()
            loss = network.train_batch(bwords)
            eloss += loss.value()
            loss.backward()
            trainer.update()
            bnum += 1
        eloss = eloss / bnum
        # Keep an exponential running average of the epoch loss.
        if last_loss:
            last_loss = 0.95 * last_loss + 0.05 * eloss
        else:
            last_loss = eloss
        print('Epoch {} loss: {:.6f} Running avg.: {:.6f}'.format(
            enum + 1, eloss, last_loss))
    return last_loss
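
For reference, a minimal sketch of driving the module above directly, assuming it is saved as character_lstm.py and that a corpus file names.txt (one word per line) exists; both file names are assumptions, not part of the gist. The driver script below does the same with command-line handling.

# Minimal usage sketch; character_lstm.py and names.txt are assumed names.
from character_lstm import WordGenerator, train, dy

dy.init()

with open('names.txt') as f:
    names = [line.strip() for line in f]
# Build the character vocabulary from the corpus.
vocab = ''.join(sorted(set(''.join(names))))

m = dy.ParameterCollection()
generator = WordGenerator(vocab, 10, 16, 1, m)
trainer = dy.AdadeltaTrainer(m)
train(generator, trainer, names, 100)
print(generator.generate(5, beam=2))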
#!/usr/bin/env python3
from character_lstm import WordGenerator, train, dy
from argparse import ArgumentParser
from sys import stdout

parser = ArgumentParser()
parser.add_argument('-m', '--model')
parser.add_argument('-g', '--generate', action='store_true')
parser.add_argument('-c', '--corpus')
opts = parser.parse_args()

dy.init()

EPSILON = 1e-3
EMB_DIM = 10
HID_DIM = 16
LAYERS = 1
NUM_GEN = 500

m = dy.ParameterCollection()
# Generate mode.
if opts.generate:
    with open(opts.model + '.vocab', 'r') as vfile:
        vocab = vfile.readline().replace('\n', '')
    generator = WordGenerator(vocab, EMB_DIM, HID_DIM, LAYERS, m)
    generator.load(opts.model)
    skipnames = set()
    if opts.corpus:
        with open(opts.corpus, 'r') as cfile:
            skipnames.update([l.strip() for l in cfile])
    gennames = []
    while len(gennames) < NUM_GEN:
        for n in generator.generate(100, beam=2):
            if n in skipnames:
                continue
            gennames.append(n)
            skipnames.add(n)
    for n in sorted(gennames):
        stdout.write(n + '\n')
# Train mode.
else:
    names = []
    vocab = set()
    with open(opts.corpus, 'r') as ifile:
        for line in ifile:
            line = line.strip()
            names.append(line)
            vocab.update(list(line))
    vocab = ''.join([c for c in sorted(vocab)])
    generator = WordGenerator(vocab, EMB_DIM, HID_DIM, LAYERS, m)
    trainer = dy.AdadeltaTrainer(m)
    # Train at least 2000 epochs to start with.
    lloss = train(generator, trainer, names, 2000)
    # Then train until the average loss stops decreasing.
    while True:
        nloss = train(generator, trainer, names, 50)
        if lloss - nloss < EPSILON:
            break
        lloss = nloss
    generator.save(opts.model)
    with open(opts.model + '.vocab', 'w') as vfile:
        vfile.write(vocab + '\n')
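
For reference, the driver supports two invocations; the script name generate_names.py is an assumption, since the gist does not show its file names.

# Train a model on a corpus of words, one per line (hypothetical file names):
#   python3 generate_names.py -c names.txt -m names_model
# Generate new words, skipping anything already present in the corpus:
#   python3 generate_names.py -g -m names_model -c names.txt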