Created
November 13, 2016 08:28
-
-
Save piyo7/27ddaed4242853b79f100ed265bcb543 to your computer and use it in GitHub Desktop.
メモリを操作するRNNでソートアルゴリズム(可変長&順序フラグあり)を機械学習できたよっ! ref: http://qiita.com/piyo7/items/3f94686d2802c290e60b
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 20000
in [7, 4] ASC
ans [4, 7]
out [4, 7] 1.0
in [1, 6, 3] ASC
ans [1, 3, 6]
out [1, 3, 6] 1.0
in [2, 6, 1, 2] ASC
ans [1, 2, 2, 6]
out [1, 2, 2, 6] 1.0
in [1, 0, 6, 1, 7] DESC
ans [7, 6, 1, 1, 0]
out [7, 6, 1, 1, 0] 1.0
in [5, 1, 2, 1, 4, 6] ASC
ans [1, 1, 2, 4, 5, 6]
out [1, 1, 2, 4, 5, 6] 1.0
in [5, 4, 3, 5, 5, 8, 5] DESC
ans [8, 5, 5, 5, 5, 4, 3]
out [8, 5, 5, 5, 5, 4, 3] 1.0
in [9, 3, 6, 9, 5, 2, 9, 5] ASC
ans [2, 3, 5, 5, 6, 9, 9, 9]
out [2, 3, 5, 5, 6, 9, 9, 9] 1.0
in [0, 6, 5, 8, 4, 6, 0, 8, 0] ASC
ans [0, 0, 0, 4, 5, 6, 6, 8, 8]
out [0, 0, 0, 4, 5, 6, 6, 8, 8] 1.0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
in [1, 6, 3] ASC
ans [1, 3, 6]
out [1, 3, 6] 1.0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# The values to be sorted are the digits 0 .. len_number - 1.
len_number = 10
# Input width: one one-hot slot per digit plus two extra slots, because the
# sort-order flag is one-hot encoded at index len_number + order (order is 0 or 1).
X = len_number + 2
# Output width: one slot per digit of the sorted answer.
Y = len_number
# Remaining DNC hyper-parameters. Presumably memory rows (N), memory word
# width (W) and number of read heads (R) -- confirm against the DNC definition.
N, W, R = 20, X, 2

mdl = DNC(X, Y, N, W, R)
opt = optimizers.Adam()
opt.setup(mdl)
def run(low, high, train=False):
    """Generate one random sorting task, then train on it or report the result.

    A task is a list of digits drawn from ``[0, len_number)``. Its length is
    ``np.random.randint(low, high)``, so ``low`` is inclusive and ``high`` is
    exclusive. A random order flag selects the sort direction: 0 means
    ascending, 1 means descending.

    The model first reads the one-hot order flag and the digits. It then
    receives one all-zero input per digit, and its outputs at those steps are
    compared with the sorted digits.

    Args:
        low: Minimum sequence length (inclusive).
        high: Maximum sequence length (exclusive).
        train: If true, sum the squared error over the answer steps and apply
            one optimizer update. Otherwise print the question, the expected
            answer, the model's answer and its per-position accuracy.
    """
    order = np.random.randint(0, 2)
    content = [np.random.randint(0, len_number)
               for _ in range(np.random.randint(low, high))]
    sorted_content = sorted(content, reverse=(order != 0))

    # Inputs: the order flag (one-hot slot len_number + order), then the
    # digits, then one blank step per digit while the answer is read out.
    # List comprehensions instead of map(...) + list keep this working on
    # Python 3, where map() returns an iterator.
    x_seq_list = ([onehot(i, X) for i in [len_number + order] + content] +
                  [np.zeros(X).astype(np.float32)] * len(content))
    # Targets: none while the question is being read, then the sorted digits.
    t_seq_list = ([None] * (1 + len(content)) +
                  [onehot(i, Y) for i in sorted_content])

    result = []
    loss = 0.0
    for x_seq, t_seq in zip(x_seq_list, t_seq_list):
        y = mdl(Variable(x_seq.reshape(1, X)))
        if t_seq is None:
            continue
        if train:
            t = Variable(t_seq.reshape(1, Y))
            loss += (y - t) ** 2
        else:
            # Plain int, so the printed list does not depend on numpy's scalar repr.
            result.append(int(np.argmax(y.data)))
    mdl.reset_state()

    if not content:
        # Empty sequence (possible when low == 0). Nothing was queried, so
        # loss is still the float 0.0 (it has no .grad/.backward) and the
        # accuracy below would divide by zero.
        return

    if train:
        mdl.cleargrads()
        # The loss is non-scalar, so the initial gradient must be set explicitly.
        loss.grad = np.ones(loss.data.shape, dtype=np.float32)
        loss.backward()
        loss.unchain_backward()
        opt.update()
    else:
        accuracy = sum(1.0 if s == r else 0.0
                       for (s, r) in zip(sorted_content, result)) / len(result)
        # Single-argument print() produces the same output on Python 2
        # (print statement) and Python 3 (print function). 'in  ' carries two
        # spaces, as the original comma-separated print statement emitted.
        print('in  %s %s' % (content, 'ASC' if order == 0 else 'DESC'))
        print('ans %s' % (sorted_content,))
        print('out %s %s' % (result, accuracy))
# Train on random sequences of length 2..8; run(2, 9) draws the length with
# an exclusive upper bound. Every 100 iterations, print one evaluation for
# each length 2..9. Length 9 is one longer than any training sequence.
for i in range(100000):
    run(2, 9, train=True)
    if i % 100 == 0:
        # Same "# <i>" output as the Python 2 `print '#', i`, but also valid
        # on Python 3.
        print('# %d' % i)
        for j in range(2, 10):
            run(j, j + 1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
$ python --version
Python 2.7.11
$ pip freeze
chainer==1.17.0
filelock==2.0.7
nose==1.3.7
numpy==1.11.2
protobuf==3.1.0.post1
six==1.10.0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment