Skip to content

Instantly share code, notes, and snippets.

@Yosshi999
Last active December 12, 2017 17:32
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save Yosshi999/078c4b4622c7c29f263e05424f3287ca to your computer and use it in GitHub Desktop.
TSG/IS18er AdventCalendar2017 12th (Qiita)
import chainer
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions
import numpy as np
import matplotlib.pyplot as plt
from scipy import interpolate
class unit(chainer.Chain):  # bundles a star together with its planets, moons, etc.
    """Epicycle generator: a fixed center plus ``dim`` chained rotating arms.

    At time ``t`` the traced point is

        x(t) = xbias + sum_i link[0, i] * cos(2*pi*(floor(omega[i]) * t + phi[i]))
        y(t) = ybias + sum_i link[0, i] * sin(2*pi*(floor(omega[i]) * t + phi[i]))

    i.e. a sum of ``dim`` circular motions (center -> planet -> moon -> ...).

    Args:
        x: initial x coordinate of the center (the "star").
        y: initial y coordinate of the center.
    """

    def __init__(self, x, y):
        super().__init__()
        self.dim = 3  # number of moving points (rotating arms)
        with self.init_scope():
            # Angular frequencies drawn from [-5, 5); floored to integers on use.
            self.omega = chainer.Parameter(
                np.random.rand(self.dim, 1).astype(np.float32) * 10 - 5)
            # Initial phases in turns (multiplied by 2*pi on use).
            self.phi = chainer.Parameter(
                np.random.rand(self.dim).astype(np.float32))
            # Radius of each circle.
            self.link = chainer.Parameter(
                np.random.rand(1, self.dim).astype(np.float32) * .5)
            # Coordinates of the center (the "star").
            self.xbias = chainer.Parameter(
                initializer=np.array(x, np.float32), shape=(1,))
            self.ybias = chainer.Parameter(
                initializer=np.array(y, np.float32), shape=(1,))

    def _phase(self, time):
        """Return the phase angle of every arm, shape ``(N, dim)``.

        ``np.c_[time]`` turns ``time`` into an ``(N, 1)`` column. omega is
        floored so that each arm makes a whole number of turns over t in
        [0, 1] and the curve closes. floor is piecewise constant, so omega
        receives no useful gradient and the optimizer effectively leaves it
        at its initial integer part.
        """
        omega = F.floor(self.omega)
        return F.linear(np.c_[time], omega, self.phi) * np.pi * 2

    def __call__(self, time):
        """Return the traced point as Variables ``(x, y)``, each of shape ``(N, 1)``."""
        phase = self._phase(time)
        x = F.linear(F.cos(phase), self.link, self.xbias)
        y = F.linear(F.sin(phase), self.link, self.ybias)
        return x, y

    def get_info(self, time):
        """Return the position of every arm tip, for plotting.

        Returns:
            A list of ``dim`` entries ``[x, y]``. Entry ``i`` holds numpy arrays
            of shape ``(N,)`` for the tip of arm ``i``: the center plus the
            first ``i + 1`` arms. The last entry equals ``self(time)``.
        """
        phase = self._phase(time)
        cos_p = F.cos(phase).data
        sin_p = F.sin(phase).data
        radii = self.link.data[0]
        base_x, base_y = self.xbias.data[0], self.ybias.data[0]
        info = []
        for i in range(self.dim):
            # Rebind (do not mutate in place) so every entry is a distinct array.
            base_x = base_x + cos_p[:, i] * radii[i]
            base_y = base_y + sin_p[:, i] * radii[i]
            info.append([base_x, base_y])
        return info
class Model(chainer.Chain):
    """Four epicycle units whose two construction lines intersect at the drawn point.

    Units 0 and 2 define one line and units 1 and 3 define another. The model's
    output at time ``t`` is the intersection of those two lines.

    Args:
        offset: initial centers ``[[x0, y0], [x1, y1], [x2, y2], [x3, y3]]``.
    """

    def __init__(self, offset):
        """offset [[x0, y0], [x1, y1], ...]"""
        super().__init__()
        # When False, the loss is reported under 'validation/loss' instead of 'loss'.
        self.train = True
        with self.init_scope():
            self.u0 = unit(*offset[0])
            self.u1 = unit(*offset[1])
            self.u2 = unit(*offset[2])
            self.u3 = unit(*offset[3])
        self.u = [self.u0, self.u1, self.u2, self.u3]

    def calc(self, time):
        """Return the final drawn point for each time stamp, shape ``(N, 1, 2)``."""
        xs, ys = [], []
        for member in self.u:
            member_x, member_y = member(time)
            xs.append(member_x)
            ys.append(member_y)
        # The line through P_a and P_b is
        #   (y_a - y_b) X + (x_b - x_a) Y = y_a * x_b - x_a * y_b.
        # Stack the two line equations (0-2 and 1-3) into A [X, Y]^T = c.
        A = F.reshape(
            F.concat((ys[0] - ys[2], xs[2] - xs[0], ys[1] - ys[3], xs[3] - xs[1])),
            (-1, 2, 2))
        Mat = F.batch_inv(A)  # fails / blows up if the two lines are parallel
        vec = F.reshape(
            F.concat((ys[0] * xs[2] - xs[0] * ys[2], ys[1] * xs[3] - xs[1] * ys[3])),
            (-1, 1, 2))
        # Row-vector form of A^{-1} c, i.e. c^T (A^{-1})^T.
        return F.matmul(vec, Mat, transb=True)

    def __call__(self, time, teacher):
        """Return the summed squared error against ``teacher`` (shape ``(N, 2)``).

        The per-sample mean is reported as ``'loss'`` while training, and as
        ``'validation/loss'`` when either ``self.train`` or
        ``chainer.config.train`` is False.
        """
        predict = self.calc(time)  # (N, 1, 2)
        diff = predict - teacher.reshape(-1, 1, 2)
        loss = F.sum(diff ** 2)
        training = self.train and chainer.config.train
        key = 'loss' if training else 'validation/loss'
        chainer.reporter.report({key: loss / len(time)})
        return loss
class TestEvaluator(extensions.Evaluator):
    """Evaluator that turns ``model.train`` off for the duration of evaluation.

    With the flag off, ``Model.__call__`` reports its loss under
    ``'validation/loss'`` rather than ``'loss'``.

    Args:
        test_iter: iterator over the evaluation dataset.
        model: the model to evaluate (registered as the 'main' target).
        trainer: the trainer; stored on the instance, not otherwise used here.
    """

    def __init__(self, test_iter, model, trainer):
        super(TestEvaluator, self).__init__(test_iter, model)
        self.trainer = trainer

    def evaluate(self):
        model = self.get_target('main')
        previous = model.train
        model.train = False
        try:
            return super(TestEvaluator, self).evaluate()
        finally:
            # Restore the flag even if evaluation raises, so a failed
            # evaluation cannot leave training reports labelled as validation.
            model.train = previous
if __name__ == '__main__':
    np.random.seed(0)
    # Hand-typed anchor points of the letter "T" drawn in one stroke:
    # top-left -> top-right -> top-center -> bottom-center.
    anchor_t = np.array([
        [0, 1],
        [1, 1],
        [0.5, 1],
        [0.5, -1]
    ])
    n_half = 60  # number of samples along the one-way stroke
    show_debug_plots = False  # set True to preview the teacher and the untrained curve

    # Linearly interpolate both coordinates at once -> shape (n_half, 2).
    half_dataset_t = interpolate.interp1d(
        np.linspace(0, 1, len(anchor_t)), anchor_t, axis=0)(np.linspace(0, 1, n_half))
    path_t_x = half_dataset_t[:, :1]  # (n_half, 1)
    path_t_y = half_dataset_t[:, 1:]  # (n_half, 1)
    # The curve should return to its start after one period, so append the
    # same stroke played backwards.
    dataset_t = np.vstack((half_dataset_t, half_dataset_t[::-1]))

    if show_debug_plots:
        plt.plot(dataset_t[:, 0], dataset_t[:, 1], '.-', label='interpolate')
        plt.plot(anchor_t[:, 0], anchor_t[:, 1], 'o', label='anchor')
        plt.legend()
        plt.show()

    # One time stamp in [0, 1] per target point, shape (len(dataset_t), 1).
    time_iterator = np.linspace(0, 1, len(dataset_t)).astype(np.float32).reshape(-1, 1)
    dataset = chainer.datasets.TupleDataset(time_iterator, dataset_t)
    train_iter = chainer.iterators.SerialIterator(dataset, 1)
    test_iter = chainer.iterators.SerialIterator(
        dataset, len(dataset), repeat=False, shuffle=False)

    model = Model([[-1, -1], [-1, 2], [2, 2], [2, -1]])
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)
    updater = chainer.training.StandardUpdater(train_iter, optimizer)
    trainer = chainer.training.Trainer(updater, (50, 'epoch'))
    trainer.extend(TestEvaluator(test_iter, model, trainer))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.PrintReport(
        ['epoch', 'loss', 'validation/loss', 'elapsed_time']
    ))

    if show_debug_plots:
        untrained = model.calc(time_iterator)
        plt.plot(untrained[..., 0].data, untrained[..., 1].data, 'b.-')
        plt.plot(path_t_x, path_t_y, 'r-')
        plt.show()

    trainer.run()

    points = model.calc(time_iterator)
    x = points[..., 0].data
    y = points[..., 1].data
    plt.plot(x, y, 'b.-', label='trained')
    plt.plot(path_t_x, path_t_y, 'r-', label='teacher')
    plt.legend()  # without this the labels above never appear in the image
    chainer.serializers.save_npz('model_dim3_T', model)
    plt.savefig('20171212_dim3.png')
from graph_optimizer import Model, unit
import chainer
import chainer.functions as F
from chainer.serializers import load_npz
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import ArtistAnimation
import sys
import matplotlib.patches as patches
# Placeholder centers; load_npz overwrites every parameter with the trained values.
offset = [[0, 0] for _ in range(4)]
model = Model(offset)
load_npz('model_dim3_T', model)
# 100 evenly spaced time stamps over one period, shape (100, 1).
time = np.linspace(0, 1, 100).astype(np.float32).reshape(-1, 1)

fig, ax = plt.subplots()
ax.set_aspect('equal', adjustable='box')
# Booleans are required here: the legacy "on"/"off" strings are no longer
# interpreted by current matplotlib, and any non-empty string is truthy.
ax.tick_params(labelbottom=False, bottom=False)  # hide the x axis
ax.tick_params(labelleft=False, left=False)  # hide the y axis
ax.set_xticklabels([])
plt.box(False)  # remove the frame
artists = []  # one list of artists per animation frame

# Simplify every unit before drawing: radii with |r| < 0.05 become 0, and the
# arms are reordered by decreasing |radius| so that smaller circles (the
# "moons") ride on larger ones. Reordering leaves the final point unchanged
# because the arms are summed.
for u in model.u:
    radii = u.link.data  # (1, dim), edited in place below
    radii[abs(radii) < 0.05] = 0
    order = np.argsort(-abs(radii[0, :]))  # largest |radius| first
    sorted_omega = F.floor(u.omega).data[order]
    sorted_phi = u.phi.data[order]
    sorted_radii = radii[0, order]
    u.link = chainer.Parameter(initializer=sorted_radii, shape=(1, u.dim))
    u.omega = chainer.Parameter(initializer=sorted_omega, shape=(u.dim, 1))
    u.phi = chainer.Parameter(initializer=sorted_phi, shape=(u.dim,))
# Build one frame per time stamp and export the result as a GIF.
# Trajectory of the drawn point, shape (len(time), 1, 2).
point = model.calc(time).data
info = np.array([u.get_info(time) for u in model.u]) # (u,dim,xy,time) [[[x[0]: shape(time,), y[0]], [x[1], y[1]]...]]
# Circle radii of every unit, in the order produced by the pruning/sorting loop above.
R = [u.link.data[0, :] for u in model.u]
color = 'rrrr'  # one matplotlib color letter per unit
# Full trajectory drawn faintly; the same Line2D artist is included in every frame.
line = ax.plot(point[:,0,0], point[:,0,1], color='k', alpha=0.5, linewidth=3)
for i in range(len(time)):
    im = []  # artists visible in frame i (creation order = drawing order)
    im.extend(line)
    im.extend(ax.plot(point[i,0,0], point[i,0,1], 'ko'))  # current drawn point
    # Construction lines between the last arm tips of units 0-2 and units 1-3;
    # Model.calc places the drawn point at their intersection.
    im.append(ax.add_patch(patches.ConnectionPatch(info[0,-1,:,i], info[2,-1,:,i], coordsA='data', color='r')))
    im.append(ax.add_patch(patches.ConnectionPatch(info[1,-1,:,i], info[3,-1,:,i], coordsA='data', color='r')))
    for ui, u in enumerate(model.u):
        xbias = u.xbias.data[0]
        ybias = u.ybias.data[0]
        # Center of this unit and the circle of its first arm.
        im.extend(ax.plot(xbias, ybias, color[ui]+'o'))
        im.append(ax.add_patch(patches.Circle((xbias, ybias), radius=R[ui][0], fill=False, color=color[ui], linewidth=0.3)))
        for dim, xy in enumerate(info[ui]):  # xy: (2, len(time)) tip of arm `dim`
            im.extend(ax.plot(xy[0,i], xy[1,i], color[ui]+'o', markersize=1))
            if dim < u.dim-1:
                # The next arm's circle is centered on this tip.
                im.append(ax.add_patch(patches.Circle(xy[:,i], radius=R[ui][dim+1], fill=False, color=color[ui], linewidth=0.3)))
            if dim == 0:
                # The first arm connects the center to its tip.
                im.append(ax.add_patch(patches.ConnectionPatch((xbias, ybias), xy[:,i], coordsA='data', color='r', linewidth=0.2)))
            else:
                # Later arms connect the previous tip to this one.
                im.append(ax.add_patch(patches.ConnectionPatch(info[ui,dim-1,:,i], xy[:,i], coordsA='data', color='r', linewidth=0.2)))
    artists.append(im)
print("ok")  # all frames have been built
# 100 ms per frame. NOTE(review): the 'imagemagick' writer presumably needs
# ImageMagick installed in the environment — confirm before relying on it.
anim = ArtistAnimation(fig, artists, interval=100)
anim.save('out.gif', writer='imagemagick')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment