Last active
September 11, 2016 09:55
-
-
Save cympfh/9c592779a63ba804294a16a7402ae910 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import re | |
import numpy | |
import chainer | |
import chainer.functions as F | |
import chainer.links as L | |
# Per-character labels: OUT for characters outside a marked span, BEGIN for
# the first character of a span, INNER for the remaining characters of it.
OUT, BEGIN, INNER = 0, 1, 2
class Vocabulary():
    """
    Character-to-id table.

    Hiragana, katakana and kanji characters are each folded into a single
    shared class token before being stored or looked up.  The token
    '<unknown>' is registered first and therefore always has id 0.
    """

    # (pattern, class token) pairs, tried in this order by normalize()
    _CHAR_CLASSES = (
        (u'[ぁ-ん]', '<hiragana>'),
        (u'[ァ-ン]', '<katakana>'),
        (u'[一-龥]', '<kanji>'),
    )

    def __init__(self):
        self.char2id = {}
        self.len = 0
        self.push('<unknown>')

    def __len__(self):
        return self.len

    def normalize(self, char):
        """Return the class token matching `char`, or `char` unchanged."""
        for pattern, token in self._CHAR_CLASSES:
            if re.match(pattern, char):
                return token
        return char

    def push(self, char):
        """Register `char` (after normalization) if new and return its id."""
        key = self.normalize(char)
        if key not in self.char2id:
            # ids are handed out densely in insertion order
            self.char2id[key] = self.len
            self.len += 1
        return self.char2id[key]

    def __getitem__(self, char):
        """Return the id of `char`, falling back to the '<unknown>' id."""
        key = self.normalize(char)
        try:
            return self.char2id[key]
        except KeyError:
            return self.char2id['<unknown>']
voc = Vocabulary()
voc.push('<s>')
voc.push('</s>')

# Every training sequence is padded (or cut) to exactly this many characters
# so that all examples share one shape and can be stacked into a minibatch.
SEQ_LEN = 200


def _parse_marked_line(line):
    """Split a "<>"-annotated line into character ids and per-character labels.

    Each "<>" marker toggles between outside and inside a span and is itself
    dropped.  The first character after an opening marker is labelled BEGIN,
    the following characters of the span INNER, everything else OUT.

    Example: "ab<>cd<>e" => ids of [a b c d e], labels [0 0 1 2 0].

    Characters are registered in the global `voc` as a side effect.

    Returns:
        (ids, labels): two lists of equal length.
    """
    chars = line.strip()
    ids = []
    labels = []
    state = OUT
    i = 0
    while i < len(chars):
        # startswith() never reads past the end of the string, so a line
        # ending in a lone '<' is handled as an ordinary character instead
        # of raising IndexError.
        if chars.startswith('<>', i):
            state = BEGIN if state == OUT else OUT
            i += 2
            continue
        ids.append(voc.push(chars[i]))
        labels.append(state)
        if state == BEGIN:
            # only the first character of a span carries BEGIN
            state = INNER
        i += 1
    return ids, labels


def _context_windows(ids):
    """Return, for every id, the 5-wide window [i-2, i-1, i, i+1, i+2].

    The sequence is framed with two '<s>' ids on the left and two '</s>' ids
    on the right so that every position has a full window,
    e.g. [1, 2] => [[s, s, 1, 2, e], [s, 1, 2, e, e]].
    """
    padded = [voc['<s>']] * 2 + ids + [voc['</s>']] * 2
    return [padded[i - 2:i + 3] for i in range(2, len(padded) - 2)]


# read train-data from stdin
dataset = []
for line in sys.stdin:
    x, y = _parse_marked_line(line)

    # alignment: cut over-long lines and pad short ones with '</s>' / OUT so
    # every example has exactly SEQ_LEN positions
    del x[SEQ_LEN:]
    del y[SEQ_LEN:]
    pad = SEQ_LEN - len(x)
    x += [voc['</s>']] * pad
    y += [OUT] * pad

    x = numpy.array(_context_windows(x), 'i')
    y = numpy.array(y, 'i')
    dataset.append((x, y))
# my model | |
class IconDetector(chainer.Chain):
    """Sequence tagger: 5-character window embedding -> tanh -> linear -> CRF."""

    def __init__(self, n, m):
        """Build the links for `n` input ids and `m` output labels."""
        embed_dim = n // 2
        super().__init__(
            embed=L.EmbedID(n, embed_dim),
            # the five embeddings of a window are concatenated before `lin`
            lin=L.Linear(embed_dim * 5, m),
            crf=L.CRF1d(m),
        )

    def forward(self, xs):
        """Return a list with one label-score output per sequence position.

        `xs` is indexed as xs[:, position, k] with k in 0..4, i.e. it holds
        a 5-id context window for every position of every sequence.
        """
        def scores_at(pos):
            window = [self.embed(xs[:, pos, k]) for k in range(5)]
            return self.lin(F.tanh(F.concat(window, axis=1)))

        return [scores_at(pos) for pos in range(xs.data.shape[1])]

    def __call__(self, xs, ts):
        """error function"""
        scores = self.forward(xs)
        # the CRF link is fed one label column per position
        labels = [ts[:, pos] for pos in range(ts.data.shape[1])]
        loss = self.crf(scores, labels)
        print('loss', loss.data)
        return loss

    def predict(self, xs):
        """Return the label path chosen by the CRF for `xs`."""
        _, path = self.crf.argmax(self.forward(xs))
        return path
# One input id per vocabulary entry, three output labels.
model = IconDetector(len(voc), 3)

opt = chainer.optimizers.Adam(alpha=0.01)
opt.setup(model)

# Train for 20 epochs on minibatches of 32 examples.
iterator = chainer.iterators.SerialIterator(dataset, batch_size=32)
updater = chainer.training.StandardUpdater(iterator, opt)
trainer = chainer.training.Trainer(updater, stop_trigger=(20, 'epoch'))
trainer.run()
# test predict
test = '世界が平和でありますように( ;´Д`)'
chars = list(test)

# One 5-id context window per character, framed by two '<s>' / '</s>' ids,
# wrapped in a leading batch axis of size 1.
x = [voc['<s>']] * 2 + [voc[c] for c in chars] + [voc['</s>']] * 2
x = [x[i - 2:i + 3] for i in range(2, len(x) - 2)]
x = numpy.array([x], 'i')

# predict() yields one entry per position; take the single batch element
y = [step[0] for step in model.predict(x)]

result = []
prev = OUT
for c, label in zip(chars, y):
    # emit "<>" at every switch between outside and inside a span
    if (label == OUT) != (prev == OUT):
        result.append('<>')
    result.append(c)
    prev = label
# close a span that runs to the end of the text
if y[-1] != OUT:
    result.append('<>')

print(y)
print(''.join(result))
# =>
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 2, 2, 2]
# 世界が平和でありますように<>( ;´Д`)<>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment