Skip to content

Instantly share code, notes, and snippets.

Created December 5, 2016 12:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nihemak/34d35770d8d9fb2f0ace47add7edf150 to your computer and use it in GitHub Desktop.
Save nihemak/34d35770d8d9fb2f0ace47add7edf150 to your computer and use it in GitHub Desktop.
import numpy as np
import chainer
from chainer import Variable, optimizers, serializers, Chain
import chainer.functions as F
import chainer.links as L
# Translator: an Encoder-Decoder translation model extended with attention.
class Translator(chainer.Chain):
    """Encoder-decoder translation model with dot-product attention.

    A single LSTM ``H`` is shared by the encoder and the decoder: the source
    sentence is fed through it first, and decoding continues from the state
    the encoder left behind.  Vocabularies are built from the training
    corpora given to the constructor.

    Written against the Chainer v1 API (``volatile`` flags on ``Variable``).
    """

    def __init__(self, debug=False, source='en.txt', target='ja.txt', embed_size=100):
        """Load both corpora, build the vocabularies and create the network.

        Args:
            debug: print the model dimensions after construction.
            source: path of the source-language corpus, one sentence per line,
                words separated by whitespace.
            target: path of the target-language corpus, line-aligned with
                ``source``.
            embed_size: width of the word embeddings and of the LSTM state.
        """
        self.embed_size = embed_size
        self.source_lines, self.source_word2id, _ = self.load_language(source)
        self.target_lines, self.target_word2id, self.target_id2word = self.load_language(target)
        source_size = len(self.source_word2id)
        target_size = len(self.target_word2id)
        super(Translator, self).__init__(
            embed_x=L.EmbedID(source_size, embed_size),
            embed_y=L.EmbedID(target_size, embed_size),
            H=L.LSTM(embed_size, embed_size),
            Wc1=L.Linear(embed_size, embed_size),
            Wc2=L.Linear(embed_size, embed_size),
            W=L.Linear(embed_size, target_size),
        )
        # The optimizer must be bound to this model's parameters; without
        # setup() the update() call in learn() has nothing to update.
        self.optimizer = optimizers.Adam()
        self.optimizer.setup(self)
        if debug:
            print("embed_size: {0}".format(embed_size), end="")
            print(", source_size: {0}".format(source_size), end="")
            print(", target_size: {0}".format(target_size))

    def learn(self, debug=False):
        """Run one training epoch over every aligned sentence pair.

        For each pair, the loss is back-propagated and the optimizer takes
        one step.  Pairs where either side is blank (for example the empty
        string produced by a trailing newline) are skipped: they carry no
        training signal.

        Args:
            debug: print a progress line after every pair.
        """
        pairs = []
        for source_line, target_line in zip(self.source_lines, self.target_lines):
            source_words = source_line.split()
            target_words = target_line.split()
            if source_words and target_words:
                pairs.append((source_words, target_words))
        line_num = len(pairs)
        for i, (source_words, target_words) in enumerate(pairs):
            self.zerograds()
            loss = self.loss(source_words, target_words)
            loss.backward()
            # Release this sentence's computational graph before the next one.
            loss.unchain_backward()
            self.optimizer.update()
            if debug:
                print("{0} / {1} line finished.".format(i + 1, line_num))

    def test(self, source_words, max_len=32):
        """Greedily translate a tokenised source sentence.

        Args:
            source_words: list of source-language words.  Every word must be
                in the source vocabulary (an unknown word raises KeyError).
            max_len: maximum number of decoding steps.  The default matches
                the previous hard-coded limit (one step plus up to 31 more).

        Returns:
            List of predicted target-language words, excluding the
            terminating ``'<eos>'``.
        """
        # Start from a clean state; the state left by training holds
        # non-volatile Variables that must not mix with volatile ones.
        self.H.reset_state()
        bar_h_i_list = self.h_i_list(source_words, True)
        eos_id = self.target_word2id['<eos>']
        # Decoding is kicked off by feeding the source-side <eos> token.
        x_i = self.embed_x(Variable(np.array([self.source_word2id['<eos>']], dtype=np.int32), volatile='on'))
        result = []
        for _ in range(max_len):
            logits = self._decode_step(x_i, bar_h_i_list, True)
            # argmax of the logits equals argmax of their softmax.
            wid = int(np.argmax(logits.data[0]))
            if wid == eos_id:
                break
            result.append(self.target_id2word[wid])
            x_i = self.embed_y(Variable(np.array([wid], dtype=np.int32), volatile='on'))
        return result

    def loss(self, source_words, target_words):
        """Return the summed softmax cross-entropy of decoding ``target_words``.

        The decoder is teacher-forced:
        - step 0 is fed the source-side ``'<eos>'`` and must predict
          ``target_words[0]``;
        - step k is fed ``target_words[k - 1]``;
        - the final step must predict ``'<eos>'``.

        The LSTM state is reset before the source is encoded, so no state
        leaks between sentences.  An empty ``target_words`` only trains the
        ``'<eos>'`` prediction.
        """
        self.H.reset_state()
        bar_h_i_list = self.h_i_list(source_words)
        expected = [self.target_word2id[word] for word in target_words]
        expected.append(self.target_word2id['<eos>'])
        x_i = self.embed_x(Variable(np.array([self.source_word2id['<eos>']], dtype=np.int32)))
        accum_loss = None
        for step, next_wid in enumerate(expected):
            if step > 0:
                x_i = self.embed_y(Variable(np.array([expected[step - 1]], dtype=np.int32)))
            logits = self._decode_step(x_i, bar_h_i_list)
            tx = Variable(np.array([next_wid], dtype=np.int32))
            loss = F.softmax_cross_entropy(logits, tx)
            accum_loss = loss if accum_loss is None else accum_loss + loss
        return accum_loss

    def _decode_step(self, x_i, bar_h_i_list, test=False):
        """Advance the LSTM by one input embedding and return the output logits.

        The new hidden state h_t attends over the encoder states to form the
        context c_t.  The two are combined as tanh(Wc1 c_t + Wc2 h_t) before
        the vocabulary projection W.
        """
        h_t = self.H(x_i)
        c_t = self.c_t(bar_h_i_list, h_t.data[0], test)
        bar_h_t = F.tanh(self.Wc1(c_t) + self.Wc2(h_t))
        return self.W(bar_h_t)

    def h_i_list(self, words, test=False):
        """Encode ``words`` with the LSTM and collect each step's hidden state.

        Advances (does not reset) the LSTM state.

        Returns:
            A list of 1-D numpy arrays, one per word, holding detached copies
            of the hidden states.
        """
        h_i_list = []
        volatile = 'on' if test else 'off'
        for word in words:
            wid = self.source_word2id[word]
            x_i = self.embed_x(Variable(np.array([wid], dtype=np.int32), volatile=volatile))
            h_i = self.H(x_i)
            # Attention is computed in numpy, so only the raw values are kept.
            h_i_list.append(np.copy(h_i.data[0]))
        return h_i_list

    def c_t(self, bar_h_i_list, h_t, test=False):
        """Return the attention context vector for decoder state ``h_t``.

        alpha_i = softmax_i(h_t . bar_h_i) and c_t = sum_i alpha_i * bar_h_i.
        This is computed in numpy, so no gradient flows back through the
        attention weights.  An empty ``bar_h_i_list`` yields a zero vector.

        Args:
            bar_h_i_list: encoder hidden states as returned by ``h_i_list``.
            h_t: decoder hidden state values, shape (E,) or (1, E).
            test: create the result as a volatile Variable.

        Returns:
            A float32 ``Variable`` of shape (1, embed_size).
        """
        if len(bar_h_i_list) > 0:
            bar_h = np.asarray(bar_h_i_list, dtype=np.float64)
            scores = bar_h.dot(np.ravel(h_t).astype(np.float64))
            # Subtract the max before exponentiating: exp() of a raw score can
            # overflow to inf and turn every weight into NaN.
            weights = np.exp(scores - scores.max())
            weights /= weights.sum()
            context = weights.dot(bar_h)
        else:
            context = np.zeros(self.embed_size)
        volatile = 'on' if test else 'off'
        return Variable(np.array([context]).astype(np.float32), volatile=volatile)

    def load_language(self, filename, encoding='utf-8'):
        """Read a corpus file and build its vocabulary.

        Args:
            filename: path of a text file with one whitespace-tokenised
                sentence per line.
            encoding: text encoding of the file.  Explicit so that non-ASCII
                corpora do not depend on the platform default.

        Returns:
            ``[lines, word2id, id2word]``:
            - ``lines`` is the raw list of lines (including a trailing empty
              string if the file ends with a newline);
            - ``word2id`` maps each word to a dense id in first-seen order,
              plus ``'<eos>'``;
            - ``id2word`` is its inverse.
        """
        with open(filename, encoding=encoding) as f:
            lines = f.read().split('\n')
        word2id = {}
        for line in lines:
            for word in line.split():
                word2id.setdefault(word, len(word2id))
        # Add <eos> only if absent; re-assigning an existing key would yield an
        # id equal to len(word2id), which is out of range for EmbedID.
        word2id.setdefault('<eos>', len(word2id))
        id2word = {v: k for k, v in word2id.items()}
        return [lines, word2id, id2word]

    def load_model(self, filename):
        """Load model parameters from an NPZ file written by ``save_model``."""
        serializers.load_npz(filename, self)

    def save_model(self, filename):
        """Write the model parameters to ``filename`` in NPZ format."""
        serializers.save_npz(filename, self)
Copy link

nihemak commented Dec 5, 2016

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment