Skip to content

Instantly share code, notes, and snippets.

@cympfh
Last active September 11, 2016 09:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cympfh/9c592779a63ba804294a16a7402ae910 to your computer and use it in GitHub Desktop.
Save cympfh/9c592779a63ba804294a16a7402ae910 to your computer and use it in GitHub Desktop.
import sys
import re
import numpy
import chainer
import chainer.functions as F
import chainer.links as L
# Per-character labels produced from the "<>" markers in the training data:
#   OUT   - character outside any marked span
#   BEGIN - first character after an opening marker
#   INNER - any later character of the marked span
OUT, BEGIN, INNER = 0, 1, 2
class Vocabulary():
    """Map (normalized) characters to consecutive integer ids.

    Characters matching the hiragana, katakana or kanji ranges are each
    collapsed into one class token before being stored or looked up, so all
    characters of one script class share a single id.  ``'<unknown>'`` is
    registered on construction (id 0) and is what ``__getitem__`` falls back
    to for characters that were never pushed.
    """

    # (pattern, class token) pairs, tried in this order by normalize().
    _SCRIPT_CLASSES = (
        (re.compile(u'[ぁ-ん]'), '<hiragana>'),
        (re.compile(u'[ァ-ン]'), '<katakana>'),
        (re.compile(u'[一-龥]'), '<kanji>'),
    )

    def __init__(self):
        self.char2id = {}  # normalized token -> id
        self.len = 0       # number of ids handed out so far
        self.push('<unknown>')

    def __len__(self):
        """Return the number of distinct tokens registered."""
        return self.len

    def normalize(self, char):
        """Return the class token for ``char``, or ``char`` itself.

        Patterns are matched at the start of the string only.
        """
        for pattern, token in self._SCRIPT_CLASSES:
            if pattern.match(char):
                return token
        return char

    def push(self, char):
        """Register ``char`` (after normalization) and return its id.

        An already-known token keeps its existing id.
        """
        token = self.normalize(char)
        try:
            return self.char2id[token]
        except KeyError:
            new_id = self.len
            self.char2id[token] = new_id
            self.len = new_id + 1
            return new_id

    def __getitem__(self, char):
        """Return the id of ``char``; unseen tokens map to ``'<unknown>'``."""
        token = self.normalize(char)
        if token not in self.char2id:
            token = '<unknown>'
        return self.char2id[token]
# Shared vocabulary; the sentence-boundary tokens are registered up front.
voc = Vocabulary()
voc.push('<s>')
voc.push('</s>')

# Read training data from stdin, one annotated sentence per line.
#
# The span to detect is wrapped in a pair of "<>" markers.  The markers are
# removed from the input and turned into per-character labels, e.g.
#   line = "abc<>def<>ghi"
#   => characters = a b c d e f g h i
#   => y          = [0 0 0 1 2 2 0 0 0]   (OUT / BEGIN / INNER)
# Each character is then represented by a window of 5 ids: itself plus the
# two characters on either side.
SEQ_LEN = 200  # every sentence is padded up to this many characters

dataset = []
for line in sys.stdin:
    chars = list(line.strip())
    x = []
    y = []
    state = OUT
    i = 0
    while i < len(chars):
        # Slicing never reads past the end of the list, so a line that ends
        # in a bare '<' is treated as an ordinary character instead of
        # raising IndexError.
        if chars[i:i + 2] == ['<', '>']:
            # A marker toggles between "outside" and "just entered a span".
            state = BEGIN if state == OUT else OUT
            i += 2
        else:
            x.append(voc.push(chars[i]))
            y.append(state)
            i += 1
            # Only the first character after an opening marker is BEGIN.
            if state == BEGIN:
                state = INNER

    # Alignment: pad short sentences so they all have SEQ_LEN positions.
    # NOTE(review): longer sentences are not truncated and therefore yield
    # longer arrays -- presumably that breaks batching; confirm.
    missing = SEQ_LEN - len(x)
    if missing > 0:
        x.extend([voc['</s>']] * missing)
        y.extend([OUT] * missing)

    # Two boundary tokens on each side give every position a full window:
    # [1,2,3,4,5,6] => [[1,2,3,4,5], [2,3,4,5,6]]
    x = [voc['<s>'], voc['<s>']] + x + [voc['</s>'], voc['</s>']]
    x = [x[i - 2:i + 3] for i in range(2, len(x) - 2)]

    x = numpy.array(x, 'i')
    y = numpy.array(y, 'i')
    dataset.append((x, y))
# my model
class IconDetector(chainer.Chain):
    """Sequence tagger: embedded 5-id window -> tanh -> linear -> CRF.

    Args:
        n: Number of distinct input ids (size of the embedding table).
        m: Number of output labels.
    """

    def __init__(self, n, m):
        embed_dim = n // 2
        super().__init__(
            embed=L.EmbedID(n, embed_dim),
            # Five embedded ids are concatenated per position.
            lin=L.Linear(embed_dim * 5, m),
            crf=L.CRF1d(m),
        )

    def forward(self, xs):
        """Return a list with one label-score output per sequence position.

        ``xs`` is indexed as ``xs[:, position, k]`` with ``k`` in 0..4, and
        the number of positions is taken from its second axis.
        """
        scores = []
        for pos in range(xs.data.shape[1]):
            window = [self.embed(xs[:, pos, k]) for k in range(5)]
            hidden = F.tanh(F.concat(window, axis=1))
            scores.append(self.lin(hidden))
        return scores

    def __call__(self, xs, ts):
        """Error function: CRF loss of ``xs`` against the labels ``ts``.

        ``ts`` is split along its second axis into one label slice per
        position.  The loss value is printed on every call.
        """
        ys = self.forward(xs)
        labels = [ts[:, pos] for pos in range(ts.data.shape[1])]
        loss = self.crf(ys, labels)
        print('loss', loss.data)
        return loss

    def predict(self, xs):
        """Return the second element of ``crf.argmax`` for ``xs``.

        NOTE(review): presumably the best label path, one entry per
        position -- confirm against the CRF1d documentation.
        """
        _score, path = self.crf.argmax(self.forward(xs))
        return path
# Three output labels: OUT, BEGIN, INNER.
model = IconDetector(n=len(voc), m=3)

opt = chainer.optimizers.Adam(alpha=0.01)
opt.setup(model)

# Train on the stdin dataset: batches of 32, for 20 epochs.
iterator = chainer.iterators.SerialIterator(dataset, batch_size=32)
updater = chainer.training.StandardUpdater(iterator, opt)
trainer = chainer.training.Trainer(updater, stop_trigger=(20, 'epoch'))
trainer.run()
# test predict
test = '世界が平和でありますように( ;´Д`)'

# Encode the sentence the same way as the training data: boundary tokens on
# both sides, then one 5-id window per character.
x = [voc['<s>']] * 2 + [voc[ch] for ch in test] + [voc['</s>']] * 2
x = [x[k - 2:k + 3] for k in range(2, len(x) - 2)]
x = numpy.array(x, 'i')
x = x.reshape((1,) + x.shape)  # add a leading batch axis of size 1

# predict() yields one entry per position; take the single batch element.
y = [step[0] for step in model.predict(x)]

# Re-insert a "<>" marker wherever the prediction switches between OUT and
# a non-OUT label, and close a span that is still open at the end.
result = []
chars = list(test)
for i, c in enumerate(chars):
    inside = bool(y[i] != OUT)
    was_inside = i > 0 and bool(y[i - 1] != OUT)
    if inside != was_inside:
        result.append('<>')
    result.append(c)
if y[-1] != OUT:
    result.append('<>')

print(y)
print(''.join(result))
# =>
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 2, 2, 2]
# 世界が平和でありますように<>( ;´Д`)<>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment