Skip to content

Instantly share code, notes, and snippets.

@piyo7
Created November 13, 2016 08:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save piyo7/27ddaed4242853b79f100ed265bcb543 to your computer and use it in GitHub Desktop.
Save piyo7/27ddaed4242853b79f100ed265bcb543 to your computer and use it in GitHub Desktop.
メモリを操作するRNNでソートアルゴリズム(可変長&順序フラグあり)を機械学習できたよっ! ref: http://qiita.com/piyo7/items/3f94686d2802c290e60b
# 20000
in [7, 4] ASC
ans [4, 7]
out [4, 7] 1.0
in [1, 6, 3] ASC
ans [1, 3, 6]
out [1, 3, 6] 1.0
in [2, 6, 1, 2] ASC
ans [1, 2, 2, 6]
out [1, 2, 2, 6] 1.0
in [1, 0, 6, 1, 7] DESC
ans [7, 6, 1, 1, 0]
out [7, 6, 1, 1, 0] 1.0
in [5, 1, 2, 1, 4, 6] ASC
ans [1, 1, 2, 4, 5, 6]
out [1, 1, 2, 4, 5, 6] 1.0
in [5, 4, 3, 5, 5, 8, 5] DESC
ans [8, 5, 5, 5, 5, 4, 3]
out [8, 5, 5, 5, 5, 4, 3] 1.0
in [9, 3, 6, 9, 5, 2, 9, 5] ASC
ans [2, 3, 5, 5, 6, 9, 9, 9]
out [2, 3, 5, 5, 6, 9, 9, 9] 1.0
in [0, 6, 5, 8, 4, 6, 0, 8, 0] ASC
ans [0, 0, 0, 4, 5, 6, 6, 8, 8]
out [0, 0, 0, 4, 5, 6, 6, 8, 8] 1.0
in [1, 6, 3] ASC
ans [1, 3, 6]
out [1, 3, 6] 1.0
# ---- Task and model configuration -----------------------------------------
# Digits to be sorted are drawn from range(len_number).
len_number = 10
# Input width: one-hot digit (len_number symbols) plus two order-flag symbols.
# run() encodes the flag as index len_number + order, with order 0 (ASC) or
# 1 (DESC). Output width: one-hot digit only.
X, Y = len_number + 2, len_number
# Remaining DNC constructor arguments. Presumably these are memory slots,
# word width and read heads (DNC paper notation); confirm against DNC.__init__.
N, W, R = 20, X, 2

# ---- Model and optimizer ----------------------------------------------------
mdl = DNC(X, Y, N, W, R)
opt = optimizers.Adam()
opt.setup(mdl)  # register mdl's parameters with the Adam optimizer
def run(low, high, train=False):
    """Generate one random sorting task and feed it through the DNC ``mdl``.

    A task is a random list of digits in ``[0, len_number)`` whose length is
    drawn with ``np.random.randint(low, high)`` (so ``high`` is exclusive),
    plus a random sort order: 0 means ascending (ASC), 1 descending (DESC).

    The model is stepped once per element of this input sequence:
      1. one order-flag step (one-hot index ``len_number + order`` over ``X``),
      2. one step per list element (one-hot over ``X``),
      3. one all-zero step per list element. During these steps the model's
         outputs are compared with the sorted list (one-hot over ``Y``).

    Args:
        low: minimum list length (inclusive).
        high: maximum list length (exclusive).
        train: if True, accumulate a squared-error loss over the output steps
            and apply one optimizer update. Otherwise print the input, the
            expected answer, the model's argmax predictions, and the fraction
            of positions predicted correctly.

    An empty list (possible only when ``low <= 0``) is skipped. It has no
    output steps, so there is nothing to train on or score.
    """
    order = np.random.randint(0, 2)
    content = [np.random.randint(0, len_number)
               for _ in range(np.random.randint(low, high))]
    sorted_content = sorted(content, reverse=(order != 0))

    if not content:
        return

    # Use comprehensions rather than map(): on Python 3 map() returns an
    # iterator, which cannot be concatenated to a list with ``+``.
    x_seq_list = ([onehot(i, X) for i in [len_number + order] + content] +
                  [np.zeros(X).astype(np.float32)] * len(content))
    t_seq_list = ([None] * (1 + len(content)) +
                  [onehot(i, Y) for i in sorted_content])

    result = []
    loss = 0.0
    for x_seq, t_seq in zip(x_seq_list, t_seq_list):
        y = mdl(Variable(x_seq.reshape(1, X)))
        if t_seq is None:
            continue  # still reading the flag/input; no output is scored yet
        if train:
            t = Variable(t_seq.reshape(1, Y))
            loss += (y - t) ** 2
        else:
            # Plain int so the printed list does not depend on numpy's repr.
            result.append(int(np.argmax(y.data)))
    mdl.reset_state()

    if train:
        mdl.cleargrads()
        # loss is an element-wise Variable (same shape as y - t), not a
        # scalar, so its gradient must be seeded explicitly before backward().
        loss.grad = np.ones(loss.data.shape, dtype=np.float32)
        loss.backward()
        loss.unchain_backward()  # drop the graph so memory is not retained
        opt.update()
    else:
        label = 'ASC' if order == 0 else 'DESC'
        hits = sum(1.0 if s == r else 0.0
                   for s, r in zip(sorted_content, result))
        accuracy = hits / len(result)
        # Single-argument print() behaves the same on Python 2 and 3 and
        # reproduces the original "print a, b, c" spacing.
        print('in  %s %s' % (content, label))
        print('ans %s' % (sorted_content,))
        print('out %s %s' % (result, accuracy))
# Training: 100000 single-task updates on lists of length 2..8
# (np.random.randint's upper bound is exclusive). Every 100 iterations, print
# one evaluation sample for each length 2..9. Length 9 is longer than any
# list seen in training.
# Single-argument print() calls produce the same output on Python 2 and 3.
for i in range(100000):
    run(2, 9, train=True)
    if i % 100 == 0:
        print('')
        print('# %s' % i)
        for j in range(2, 10):
            print('')
            run(j, j + 1)  # randint(j, j + 1) always yields length j
$ python --version
Python 2.7.11
$ pip freeze
chainer==1.17.0
filelock==2.0.7
nose==1.3.7
numpy==1.11.2
protobuf==3.1.0.post1
six==1.10.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment