Skip to content

Instantly share code, notes, and snippets.

@suzusuzu
Last active November 11, 2019 19:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save suzusuzu/fdc14964eb0842efdb645d14a6964a46 to your computer and use it in GitHub Desktop.
Save suzusuzu/fdc14964eb0842efdb645d14a6964a46 to your computer and use it in GitHub Desktop.
An implementation of the Gaussian mean shift procedure (3D visualization)
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mpl_toolkits.mplot3d import Axes3D
def kde(data, sigma):
    """Build an (unnormalised) kernel density estimator over ``data``.

    Args:
        data: 2-D array whose rows are the sample points.
        sigma: Bandwidth forwarded to ``gaussian_kernel``.

    Returns:
        A function that takes a 2-D array of query points (one per row) and
        returns a 1-D array holding, for every query point, the sum of the
        kernel values between that point and every row of ``data``.
    """
    def density(points):
        n_points = points.shape[0]
        values = np.zeros(n_points)
        for idx, point in enumerate(points):
            # Broadcasting yields one difference vector per sample row.
            values[idx] = np.sum(gaussian_kernel(point - data, sigma))
        return values

    return density
def gaussian_kernel(x, sigma):
    """Evaluate an isotropic Gaussian kernel on every row of ``x``.

    Args:
        x: 2-D array; each row is a (difference) vector.
        sigma: Kernel bandwidth (standard deviation).

    Returns:
        1-D array with one value per row of ``x``:
        ``exp(-||x_i||^2 / (2 * sigma^2)) / (sqrt(2 * pi) * sigma)``.
    """
    # A Gaussian decays with the *squared* Euclidean distance; using the
    # plain norm here would give an exponential (Laplacian-like) kernel.
    squared_norm = np.sum(np.square(x), axis=1)
    # NOTE: the prefactor is the one-dimensional normalisation constant. It
    # is kept as-is so the scale of the values is unchanged; it cancels in
    # any ratio of kernel sums such as a weighted mean.
    scale = 1 / (np.sqrt(2 * np.pi) * sigma)
    return scale * np.exp(-squared_norm / (2 * (sigma ** 2)))
def x_update(x, xi, sigma):
    """Return the mean-shift update of the point ``xi``.

    The new position is the weighted mean of the rows of ``x``, where the
    weight of each row is ``gaussian_kernel(xi - x, sigma)``.

    Args:
        x: 2-D array of data points, one per row.
        xi: 1-D array, the current position of the point being shifted.
        sigma: Kernel bandwidth forwarded to ``gaussian_kernel``.

    Returns:
        1-D array with the updated position. If the total weight is zero
        (e.g. the kernel underflows for every data point), a float copy of
        ``xi`` is returned so the point stays where it is.
    """
    # Evaluate the kernel once and reuse it for numerator and denominator.
    weights = gaussian_kernel(xi - x, sigma)
    total = np.sum(weights)
    if total == 0:
        # Avoid a 0/0 division that would turn the point into NaNs.
        return np.array(xi, dtype=float)
    return np.sum(weights.reshape(-1, 1) * x, axis=0) / total
def gaussian_mean_shift(x, sigma, max_iter=1000):
    """Run the mean-shift procedure on every point of ``x``.

    Each point is repeatedly replaced by ``x_update(x, point, sigma)``; the
    original data ``x`` is always used as the reference set. Iteration stops
    when the mean L1 displacement of the points in one sweep drops below
    ``1e-10`` or after ``max_iter`` sweeps.

    Args:
        x: 2-D array of data points, one per row.
        sigma: Kernel bandwidth forwarded to ``x_update``.
        max_iter: Maximum number of sweeps over all points.

    Returns:
        A tuple ``(x_, history)`` where ``x_`` holds the final positions and
        ``history`` is an array stacking the positions as they were at the
        start of each sweep.
    """
    x = np.asarray(x)
    x_ = np.array(x, copy=True)
    if not np.issubdtype(x_.dtype, np.inexact):
        # Integer (or boolean) input would silently truncate the fractional
        # updates written back into the working copy.
        x_ = x_.astype(float)
    n_points = x.shape[0]
    history = []
    for _ in range(max_iter):
        x_old = np.copy(x_)
        history.append(x_old)
        for idx in range(n_points):
            x_[idx] = x_update(x, x_[idx], sigma)
        # Converged once the points have (on average) stopped moving.
        if np.mean(np.linalg.norm(x_ - x_old, ord=1, axis=1)) < 1e-10:
            break
    history = np.asarray(history)
    return x_, history
# Fixed seed so the generated data set is reproducible.
np.random.seed(0)

# Two-component Gaussian mixture in 2-D: 250 points drawn around 0.0 and
# 500 points drawn around 5.0 (both with unit standard deviation).
data = np.random.normal(0.0, 1.0, size=500).reshape(-1, 2)
tmp = np.random.normal(5.0, 1.0, size=1000).reshape(-1, 2)
data = np.vstack([data, tmp])

sigma = 0.7
k = kde(data, sigma)
max_x, x_history = gaussian_mean_shift(data, sigma)

# 3d plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.view_init(elev=65.0)

# Evaluate the density estimate on a regular grid to draw it as a surface.
nlinspace = 100
x = np.linspace(-3, 8, nlinspace)
y = np.linspace(-3, 8, nlinspace)
X, Y = np.meshgrid(x, y)
X_ = X.reshape(-1)
Y_ = Y.reshape(-1)
XY_ = np.c_[X_, Y_]
Z = k(XY_).reshape(nlinspace, nlinspace)
surf = ax.plot_surface(X, Y, Z, cmap='hsv', antialiased=True)

# Start the points on the density surface (not at z=0) so the initial
# artist is consistent with what every animation frame draws.
scat = ax.scatter(
    x_history[0, :, 0],
    x_history[0, :, 1],
    k(x_history[0]),
    c='black',
    label='solution',
)
frames = x_history.shape[0]


def update(i):
    """Move the scatter points to their positions at mean-shift sweep ``i``."""
    points = x_history[i]
    heights = k(points)
    # NOTE: `_offsets3d` is a private matplotlib attribute and may change
    # between matplotlib versions.
    scat._offsets3d = (points[:, 0], points[:, 1], heights)
    return scat,


# Keep a reference to the animation object so it is not garbage-collected
# before the window is shown.
ani = animation.FuncAnimation(fig, update, frames=frames, interval=300)
plt.legend()
plt.show()
# To export the animation instead of only showing it:
# ani.save("3d_mean_shift.gif", writer = 'imagemagick')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment