Skip to content

Instantly share code, notes, and snippets.

@ytbilly3636
Created June 6, 2017 10:32
Show Gist options
  • Save ytbilly3636/8f6de6c95fc2efdc99460a292c103860 to your computer and use it in GitHub Desktop.
Save ytbilly3636/8f6de6c95fc2efdc99460a292c103860 to your computer and use it in GitHub Desktop.
水を求めるミミズ
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
class TravelPoints(object):
    """A chain of 2-D points pulled toward target points by a 1-D SOM rule.

    Each ``update`` performs one Kohonen-style step: the point closest to the
    target wins, and every point moves toward the target weighted by a
    Gaussian of its index distance from the winner. Index distance does not
    wrap around, so the chain is an open path (a "worm"), not a closed ring.
    """

    def __init__(self, num_points):
        """Create ``num_points`` chain points, all initialised to (1, 1).

        Args:
            num_points: number of points in the chain.
        """
        # NOTE(review): allocated as float32, but ``update`` rebinds
        # ``self.points`` to a new array; with float64 targets (such as the
        # output of np.random.rand) it becomes float64 after the first update.
        self.points = np.ones((num_points, 2), dtype=np.float32)

    def update(self, point_city, lr, var):
        """Pull the chain toward ``point_city`` (one SOM training step).

        Args:
            point_city: target position; must broadcast against the
                (num_points, 2) point array, i.e. a length-2 vector.
            lr: learning rate. With lr == 1.0 the winning point lands
                exactly on ``point_city`` (its neighbourhood weight is 1).
            var: width of the Gaussian neighbourhood (see ``coef_near``).

        Raises:
            ValueError: propagated from ``coef_near`` if ``var`` <= 0.
        """
        offsets = point_city - self.points
        index_nearest = np.argmin(np.linalg.norm(offsets, axis=1))
        # Column vector so each point's weight scales both of its coordinates.
        weights = self.coef_near(index_nearest, var).reshape(-1, 1)
        self.points = self.points + lr * weights * offsets

    def coef_near(self, index, var):
        """Return Gaussian neighbourhood weights centred on point ``index``.

        Weight ``i`` is ``exp(-(i - index)**2 / (2 * var**2))``: the point at
        ``index`` gets 1.0, and weights decay with index distance along the
        (non-wrapping) chain.

        Args:
            index: index of the centre (winning) point.
            var: Gaussian width; despite the name it is used as a standard
                deviation. Must be positive.

        Returns:
            float64 array of shape (num_points,).

        Raises:
            ValueError: if ``var`` is not positive (0 would produce NaNs).
        """
        if var <= 0:
            raise ValueError("var must be positive, got %r" % (var,))
        distances = np.arange(self.points.shape[0]) - index
        # Force float arithmetic: on Python 2 an integer ``var`` would turn
        # this into integer floor division and corrupt the weights.
        return np.exp(-(distances.astype(np.float64) ** 2) / (2.0 * var ** 2))
# Train a 100-point chain toward 20 random "cities" in the unit square and
# animate its progress.
points = TravelPoints(num_points=100)
cities = np.random.rand(20, 2)
# Neighbourhood width is constant across epochs, so set it once.
var = 1.5
for epoch in range(100):
    # print(x) with a single argument prints the same on Python 2 and 3.
    print(epoch)
    # Learning rate decays geometrically so the chain settles over time.
    lr = 0.2 * 0.99 ** epoch
    for city in cities:
        points.update(city, lr, var)
    # Redraw cities (blue dots) and the chain (red line) once per epoch.
    plt.clf()
    plt.plot(cities[:, 0], cities[:, 1], 'bo')
    plt.plot(points.points[:, 0], points.points[:, 1], 'r')
    plt.pause(.001)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment