Last active
December 12, 2017 17:32
Star
You must be signed in to star a gist
TSG/IS18er AdventCalendar2017 12th (Qiita)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import chainer | |
import chainer.functions as F | |
import chainer.links as L | |
from chainer.training import extensions | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from scipy import interpolate | |
class unit(chainer.Chain):
    """A star / planet / satellite system: a chain of rotating circles.

    The unit has a fixed centre ``(xbias, ybias)`` (the "star") and
    ``self.dim`` circles stacked on top of each other.  Circle ``i`` has
    radius ``link[0, i]``, angular frequency ``floor(omega[i])`` and phase
    offset ``phi[i]``; the tip of the unit is the centre plus the sum of all
    circle contributions.
    """

    def __init__(self, x, y):
        """Create a unit whose centre is initialised at ``(x, y)``."""
        super().__init__()
        self.dim = 3  # number of moving points
        with self.init_scope():
            # NOTE: the order of the np.random calls is kept as-is so that a
            # fixed seed keeps producing the same initial parameters.
            self.omega = chainer.Parameter(
                np.random.rand(self.dim, 1).astype(np.float32) * 10 - 5)
            self.phi = chainer.Parameter(
                np.random.rand(self.dim).astype(np.float32))
            self.link = chainer.Parameter(
                np.random.rand(1, self.dim).astype(np.float32) * .5)  # radius of each circle
            self.xbias = chainer.Parameter(
                initializer=np.array(x, np.float32), shape=(1,))  # x coordinate of the star
            self.ybias = chainer.Parameter(
                initializer=np.array(y, np.float32), shape=(1,))  # y coordinate of the star

    def _circle_coords(self, time):
        """Return ``(cos, sin)`` of every circle's phase at each time step.

        Both results have one row per time step and one column per circle.
        """
        omega = F.floor(self.omega)  # omega is meant to be an integer
        phase = F.linear(np.c_[time], omega, self.phi) * np.pi * 2
        return F.cos(phase), F.sin(phase)

    def __call__(self, time):
        """Return the ``(x, y)`` position of the unit's tip for each time."""
        cos, sin = self._circle_coords(time)
        # Weighted sum of the circle contributions plus the star position.
        x = F.linear(cos, self.link, self.xbias)
        y = F.linear(sin, self.link, self.ybias)
        return x, y

    def get_info(self, time):
        """Return the trajectory of every intermediate point of the chain.

        The result is a list with ``self.dim`` entries; entry ``i`` is
        ``[x, y]`` where ``x`` and ``y`` are arrays (one value per time step)
        giving the position reached after adding circles ``0..i`` to the
        star position.
        """
        cos, sin = self._circle_coords(time)
        cos, sin = cos.data, sin.data
        radii = self.link.data
        base_x, base_y = self.xbias.data[0], self.ybias.data[0]
        info = []
        for i in range(self.dim):
            # Rebinding (not in-place +=) gives each entry its own array.
            base_x = base_x + cos[:, i] * radii[0, i]
            base_y = base_y + sin[:, i] * radii[0, i]
            info.append([base_x, base_y])
        return info
class Model(chainer.Chain):
    """Four ``unit`` chains whose tips define two lines; the model output is
    the intersection of the line through units 0 and 2 with the line through
    units 1 and 3.
    """

    def __init__(self, offset):
        """Build the four units.

        ``offset`` is ``[[x0, y0], [x1, y1], ...]``: the initial centre of
        each unit.
        """
        super().__init__()
        self.train = True
        with self.init_scope():
            self.u0 = unit(*offset[0])
            self.u1 = unit(*offset[1])
            self.u2 = unit(*offset[2])
            self.u3 = unit(*offset[3])
        self.u = [self.u0, self.u1, self.u2, self.u3]

    def calc(self, time):
        """Return the position of the final (intersection) point."""
        tips = [self.u[k](time) for k in range(4)]
        (x0, y0), (x1, y1), (x2, y2), (x3, y3) = tips

        # Each row of the 2x2 system is the implicit equation of one line:
        #   (ya - yb) * X + (xb - xa) * Y = ya * xb - xa * yb
        coefficients = F.concat((y0 - y2, x2 - x0, y1 - y3, x3 - x1))
        inverse = F.batch_inv(F.reshape(coefficients, (-1, 2, 2)))

        constants = F.concat((y0 * x2 - x0 * y2, y1 * x3 - x1 * y3))
        constants = F.reshape(constants, (-1, 1, 2))

        # Row-vector form of inverse @ constants.
        return F.matmul(constants, inverse, transb=True)

    def __call__(self, time, teacher):
        """Return the summed squared error against ``teacher``.

        The per-sample mean is reported as ``loss`` while ``self.train`` is
        true and as ``validation/loss`` otherwise.
        """
        # Prediction is (batch, 1, 2); teacher is reshaped to match.
        residual = self.calc(time) - teacher.reshape(-1, 1, 2)
        loss = F.sum(residual ** 2)
        key = 'loss' if self.train else 'validation/loss'
        chainer.reporter.report({key: loss / len(time)})
        return loss
class TestEvaluator(extensions.Evaluator):
    """Evaluator that switches the model to validation reporting.

    While evaluating, ``model.train`` is set to ``False`` so that the model
    reports ``validation/loss`` instead of ``loss``.
    """

    def __init__(self, test_iter, model, trainer):
        """Store the trainer and delegate the rest to ``Evaluator``."""
        super(TestEvaluator, self).__init__(test_iter, model)
        self.trainer = trainer

    def evaluate(self):
        """Run the standard evaluation with ``model.train`` turned off."""
        model = self.get_target('main')
        previous = model.train
        model.train = False
        try:
            return super(TestEvaluator, self).evaluate()
        finally:
            # Restore the flag even if evaluation raises, so a failed
            # validation pass cannot leave the model in validation mode.
            model.train = previous
def main():
    """Fit the model to a one-stroke letter "T" and save the result.

    Writes the trained parameters to ``model_dim3_T`` and a comparison plot
    of the trained trajectory against the teacher path to
    ``20171212_dim3.png``.
    """
    np.random.seed(0)
    anchor_t = np.array([  # typed in by hand
        [0, 1],
        [1, 1],
        [0.5, 1],
        [0.5, -1],
    ])

    # One-stroke path of the letter T: linear interpolation between anchors.
    knots = np.linspace(0, 1, len(anchor_t))
    samples = np.linspace(0, 1, 60)
    path_t_x = interpolate.interp1d(knots, anchor_t[:, 0])(samples).reshape(-1, 1)
    path_t_y = interpolate.interp1d(knots, anchor_t[:, 1])(samples).reshape(-1, 1)
    half_dataset_t = np.hstack((path_t_x, path_t_y))

    # The point should be back at its start after one period, so the path is
    # followed by its own reverse.  Cast to float32 to match the dtype of
    # the model parameters and of the time input.
    dataset_t = np.vstack(
        (half_dataset_t, half_dataset_t[::-1])).astype(np.float32)

    # Optional: inspect the teacher path.
    # plt.plot(dataset_t[:,0], dataset_t[:,1], '.-',label='interpolate')
    # plt.plot(anchor_t[:,0], anchor_t[:,1], 'o', label='anchor')
    # plt.legend()
    # plt.show()

    time_iterator = np.linspace(
        0, 1, len(dataset_t)).astype(np.float32).reshape(-1, 1)
    dataset = chainer.datasets.TupleDataset(time_iterator, dataset_t)
    train_iter = chainer.iterators.SerialIterator(dataset, 1)
    test_iter = chainer.iterators.SerialIterator(
        dataset, len(dataset), repeat=False, shuffle=False)

    model = Model([[-1, -1], [-1, 2], [2, 2], [2, -1]])
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)
    updater = chainer.training.StandardUpdater(train_iter, optimizer)
    trainer = chainer.training.Trainer(updater, (50, 'epoch'))
    trainer.extend(TestEvaluator(test_iter, model, trainer))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.PrintReport(
        ['epoch', 'loss', 'validation/loss', 'elapsed_time']
    ))

    # Optional: plot the untrained trajectory.
    # points = model.calc(time_iterator)
    # x = points[...,0].data
    # y = points[...,1].data
    # plt.plot(x, y, 'b.-')
    # plt.plot(path_t_x, path_t_y, 'r-')
    # plt.show()

    trainer.run()

    points = model.calc(time_iterator)
    x = points[..., 0].data
    y = points[..., 1].data
    plt.plot(x, y, 'b.-', label='trained')
    plt.plot(path_t_x, path_t_y, 'r-', label='teacher')
    plt.legend()  # without this the labels above are never drawn
    chainer.serializers.save_npz('model_dim3_T', model)
    plt.savefig('20171212_dim3.png')


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from graph_optimizer import Model, unit | |
import chainer | |
import chainer.functions as F | |
from chainer.serializers import load_npz | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from matplotlib.animation import ArtistAnimation | |
import sys | |
import matplotlib.patches as patches | |
# Render the trained model as an animated GIF ('out.gif'): every frame shows
# the four units with their circles, the two connecting lines and the
# resulting intersection point.

# Placeholder centres; presumably replaced by the values stored in the
# snapshot when it is loaded below.
offset = [[0, 0] for _ in range(4)]
model = Model(offset)
load_npz('model_dim3_T', model)
time = np.linspace(0, 1, 100).astype(np.float32).reshape(-1, 1)

fig, ax = plt.subplots()
plt.gca().set_aspect('equal', adjustable='box')
# Booleans instead of the string "off": a non-empty string is truthy, so
# current matplotlib would keep the ticks and the frame switched on.
ax.tick_params(labelbottom=False, bottom=False)  # hide the x axis
ax.tick_params(labelleft=False, left=False)  # hide the y axis
ax.set_xticklabels([])
plt.box(False)  # hide the frame

artists = []  # animation frames

for u in model.u:
    W = u.link.data
    # Collapse negligible circles to radius zero.
    W[np.abs(W) < 0.05] = 0
    # Reorder so that satellites (later circles) are the smaller ones.
    argsort = np.argsort(-np.abs(W[0, :]))
    omega = F.floor(u.omega).data[argsort]
    phi = u.phi.data[argsort]
    W = W[0, argsort]
    u.link = chainer.Parameter(initializer=W, shape=(1, u.dim))
    u.omega = chainer.Parameter(initializer=omega, shape=(u.dim, 1))
    u.phi = chainer.Parameter(initializer=phi, shape=(u.dim,))

point = model.calc(time).data
# info is indexed as (unit, circle, xy, time).
info = np.array([u.get_info(time) for u in model.u])
R = [u.link.data[0, :] for u in model.u]
color = 'rrrr'

# Full trajectory of the intersection point, shown in every frame.
line = ax.plot(point[:, 0, 0], point[:, 0, 1], color='k', alpha=0.5, linewidth=3)

for i in range(len(time)):
    im = []
    im.extend(line)
    im.extend(ax.plot(point[i, 0, 0], point[i, 0, 1], 'ko'))
    # The two lines whose intersection is the drawn point: tips of units
    # 0 and 2, and tips of units 1 and 3.
    im.append(ax.add_patch(patches.ConnectionPatch(
        info[0, -1, :, i], info[2, -1, :, i], coordsA='data', color='r')))
    im.append(ax.add_patch(patches.ConnectionPatch(
        info[1, -1, :, i], info[3, -1, :, i], coordsA='data', color='r')))
    for ui, u in enumerate(model.u):
        xbias = u.xbias.data[0]
        ybias = u.ybias.data[0]
        im.extend(ax.plot(xbias, ybias, color[ui] + 'o'))
        # Circle around the star; its radius belongs to the first link.
        im.append(ax.add_patch(patches.Circle(
            (xbias, ybias), radius=R[ui][0], fill=False,
            color=color[ui], linewidth=0.3)))
        for dim, xy in enumerate(info[ui]):
            im.extend(ax.plot(xy[0, i], xy[1, i], color[ui] + 'o', markersize=1))
            if dim < u.dim - 1:
                # Circle carried by this point, for the next link.
                im.append(ax.add_patch(patches.Circle(
                    xy[:, i], radius=R[ui][dim + 1], fill=False,
                    color=color[ui], linewidth=0.3)))
            if dim == 0:
                start = (xbias, ybias)
            else:
                start = info[ui, dim - 1, :, i]
            # Arm from the previous point (or the star) to this point.
            im.append(ax.add_patch(patches.ConnectionPatch(
                start, xy[:, i], coordsA='data', color='r', linewidth=0.2)))
    artists.append(im)

print("ok")
anim = ArtistAnimation(fig, artists, interval=100)
anim.save('out.gif', writer='imagemagick')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment