Created
December 17, 2020 10:27
-
-
Save tanacchi/d737464531676dad933c48c4e5d58add to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def gen_saddle_shape(resolution, random_seed=0, noise_scale=0.1):
    """Sample noisy points from the saddle surface ``z = x**2 - y**2``.

    Args:
        resolution: Number of points to generate.
        random_seed: Value passed to ``np.random.seed``. Note that this
            reseeds NumPy's global random generator.
        noise_scale: Standard deviation of the zero-mean Gaussian noise that
            is added to every coordinate.

    Returns:
        Array of shape ``(resolution, 3)`` whose columns are x, y and z.
    """
    np.random.seed(random_seed)
    # x and y are drawn uniformly from [-1, 1).
    xs = 2.0 * np.random.rand(resolution) - 1.0
    ys = 2.0 * np.random.rand(resolution) - 1.0
    surface = np.column_stack((xs, ys, xs**2 - ys**2))
    noise = np.random.normal(loc=0, scale=noise_scale, size=surface.shape)
    return surface + noise
if __name__ == '__main__':
    # Demo: show the generated point cloud as a rotating 3-D scatter plot and
    # then write the animation out as a GIF.
    from matplotlib import pyplot as plt
    from matplotlib.animation import FuncAnimation

    def update_graph(angle, X, fig, ax):
        """Redraw the scatter plot with the camera at azimuth ``angle``."""
        ax.cla()
        ax.view_init(azim=angle, elev=30)
        xs, ys, zs = X[:, 0], X[:, 1], X[:, 2]
        # Points are coloured by their x coordinate.
        ax.scatter(xs, ys, zs, c=xs)

    X = gen_saddle_shape(200)
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(1, 1, 1, projection='3d')
    # One frame per degree of azimuth, 30 ms apart.
    animation_options = dict(frames=360,
                             interval=30,
                             repeat=True,
                             fargs=(X, fig, ax))
    ani = FuncAnimation(fig, update_graph, **animation_options)
    plt.show()
    ani.save("tmp.gif", writer='pillow')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pickle | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from scipy.spatial import distance as dist | |
from data import gen_saddle_shape | |
from visualizer import visualize_history | |
# Bandwidth of the Gaussian kernel below.
sigma = 0.5


def kernel(Z1, Z2):
    """Return the Gaussian kernel matrix between the rows of Z1 and Z2.

    Entry ``(i, j)`` is ``exp(-d(Z1[i], Z2[j])**2 / (2 * sigma**2))`` where
    ``d`` is the distance computed by ``scipy.spatial.distance.cdist`` and
    ``sigma`` is the module-level bandwidth.
    """
    squared_distances = dist.cdist(Z1, Z2) ** 2
    return np.exp(-squared_distances / (2 * sigma**2))
def estimate_f(X, Z1, Z2=None):
    """Kernel-smooth the observations ``X`` at the query points ``Z1``.

    Every output row is a weighted average of the rows of ``X``. The weights
    are the kernel values between one row of ``Z1`` and all rows of ``Z2``,
    normalised so that each row of weights sums to one.

    Args:
        X: Observations; one row per row of ``Z2``.
        Z1: Points at which the smoothed value is evaluated.
        Z2: Points paired with the rows of ``X``. A copy of ``Z1`` is used
            when this is omitted.

    Returns:
        The product of the row-normalised kernel matrix and ``X``.
    """
    if Z2 is None:
        Z2 = np.copy(Z1)
    weights = kernel(Z1, Z2)
    # Normalise each row so the weights form a convex combination.
    weights = weights / weights.sum(axis=1, keepdims=True)
    return weights @ X
def make_grid2d(resolution, bounds=(-1, +1)):
    """Build a square 2-D grid of cell-centre points.

    The interval ``[bounds[0], bounds[1])`` is divided into ``resolution``
    equal cells along each axis, and the centre of every cell is used as a
    grid coordinate.

    Args:
        resolution: Number of grid points along each axis.
        bounds: ``(lower, upper)`` limits shared by both axes.

    Returns:
        Array of shape ``(resolution**2, 2)``. The first column varies
        fastest from row to row.
    """
    left_edges, cell_width = np.linspace(bounds[0],
                                         bounds[1],
                                         resolution,
                                         endpoint=False,
                                         retstep=True)
    # Shift from the left edge of each cell to its centre.
    centres = left_edges + cell_width / 2.0
    xx, yy = np.meshgrid(centres, centres)
    return np.column_stack((xx.ravel(), yy.ravel()))
# Load the observations and the per-epoch latent variables written by the
# training script.
# NOTE: pickle.load can execute arbitrary code; only load files that were
# produced locally by the accompanying training script.
with open("X.pickle", 'rb') as f:
    X = pickle.load(f)
with open("Z_history.pickle", 'rb') as f:
    Z_history = pickle.load(f)

# Number of grid nodes along each latent axis.
resolution = 10

# For every epoch, map a regular grid spanning that epoch's latent range into
# the observed space. The observed dimension is taken from X instead of being
# hard-coded, so the buffer always matches what estimate_f returns.
f_history = np.zeros((Z_history.shape[0], resolution**2, X.shape[1]))
for i, Z in enumerate(Z_history):
    Zeta = make_grid2d(resolution, (Z.min(), Z.max()))
    f_history[i] = estimate_f(X, Zeta, Z)

# Animate the history; the final True asks for the animation to be saved.
visualize_history(X, f_history, Z_history, True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pickle | |
from collections import OrderedDict | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torch.autograd import Variable | |
from torchvision import datasets, models, transforms | |
from tqdm import tqdm | |
from data import gen_saddle_shape | |
from ukr_nn import UKRNet | |
# Device selection: run on the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data preparation. The same N observations are repeated `samples` times, so
# with batch_size=1 every epoch performs `samples` updates on the full set.
X = torch.from_numpy(gen_saddle_shape(N := 100).astype(np.float32)).to(device)
X_train = X.repeat(samples := 1000, 1, 1)
train = torch.utils.data.TensorDataset(X_train, X_train)
trainloader = torch.utils.data.DataLoader(train, batch_size=1, shuffle=True)

# Model and optimisation settings.
model = UKRNet(N).to(device)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(),
                      lr=0.01,
                      momentum=0.9,
                      weight_decay=1e-4)

# Containers for the per-epoch latent variables and loss values.
num_epoch = 200
Z_history = np.zeros((num_epoch, N, 2))
losses = []

# Training loop.
with tqdm(range(num_epoch)) as pbar:
    for epoch in pbar:
        running_loss = 0.0
        # Tensors take part in autograd directly, so the deprecated
        # torch.autograd.Variable wrapper is not applied to the batch.
        for inputs, targets in trainloader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Store the latent variables reached at the end of this epoch.
        Z_history[epoch] = model.layer.Z.detach().cpu().numpy()
        # Store the loss summed over the epoch.
        losses.append(running_loss)
        # Show the progress in the progress bar.
        pbar.set_postfix(
            OrderedDict(epoch=f"{epoch + 1}", loss=f"{running_loss:.3f}"))

# Plot how the loss evolved over the epochs.
plt.plot(np.arange(num_epoch), np.array(losses))
plt.xlabel("epoch")
plt.show()

# Save the training results as *.pickle files.
with open("./X.pickle", 'wb') as f:
    pickle.dump(X.detach().cpu().numpy(), f)
with open("./Z_history.pickle", 'wb') as f:
    pickle.dump(Z_history, f)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
import torch.nn as nn | |
class UKRLayer(nn.Module):
    """Kernel-smoothing layer with one learnable latent point per data row.

    The layer owns a parameter ``Z`` of shape ``(data_num, latent_dim)``.
    ``forward`` builds the Gaussian kernel matrix of ``Z`` against itself,
    normalises each row to sum to one, and multiplies the result with the
    input.

    Args:
        data_num: Number of latent points (rows of ``Z``).
        latent_dim: Dimension of each latent point.
        sigma: Bandwidth of the Gaussian kernel.
        random_seed: Value passed to ``torch.manual_seed`` before ``Z`` is
            initialised. Note that this reseeds torch's global generator.
    """

    def __init__(self, data_num, latent_dim, sigma=1, random_seed=0):
        super().__init__()
        # The bandwidth is stored on the instance and the kernel is a regular
        # method: a lambda attribute would make the module unpicklable.
        self.sigma = sigma
        torch.manual_seed(random_seed)
        # Small random initial latent positions.
        self.Z = nn.Parameter(torch.randn(data_num, latent_dim) / 10.)

    def kernel(self, Z1, Z2):
        """Return the Gaussian kernel matrix between the rows of Z1 and Z2."""
        squared_distances = torch.cdist(Z1, Z2)**2
        return torch.exp(-squared_distances / (2 * self.sigma**2))

    def forward(self, X):
        """Return the kernel-weighted average of the rows of ``X``.

        ``X`` must be matrix-multipliable by the ``(data_num, data_num)``
        weight matrix.
        """
        kernels = self.kernel(self.Z, self.Z)
        # Row-normalise so every output row is a weighted average.
        R = kernels / torch.sum(kernels, dim=1, keepdim=True)
        return R @ X
class UKRNet(nn.Module):
    """Network made of a single ``UKRLayer``.

    Args:
        N: Number of data points, forwarded as the layer's ``data_num``.
        latent_dim: Dimension of the latent space.
        sigma: Kernel bandwidth handed to the layer.
    """

    def __init__(self, N, latent_dim=2, sigma=2):
        super().__init__()
        self.layer = UKRLayer(data_num=N, latent_dim=latent_dim, sigma=sigma)

    def forward(self, x):
        """Pass ``x`` through the layer and return its output."""
        smoothed = self.layer(x)
        return smoothed
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from matplotlib import pyplot as plt | |
from matplotlib.animation import FuncAnimation | |
def visualize_history(X, Y_history, Z_history, save_gif=False):
    """Animate the per-epoch history of the observed and latent spaces.

    The left panel shows the observations together with the per-epoch
    reconstruction, the right panel shows the latent variables.

    Args:
        X: Observations; the number of columns selects the 2-D or 3-D drawer.
        Y_history: Per-epoch reconstructions. For 3-D observations with a 2-D
            latent space it is reshaped into a square mesh per epoch.
        Z_history: Per-epoch latent variables; the number of columns selects
            the 1-D or 2-D latent drawer.
        save_gif: When true, the animation is written to ``tmp.gif`` after
            the window has been shown.
    """
    input_dim = X.shape[1]
    latent_dim = Z_history[0].shape[1]
    num_epoch = len(Y_history)

    figure = plt.figure(figsize=(10, 5))
    if input_dim > 2:
        input_ax = figure.add_subplot(1, 2, 1, projection='3d')
    else:
        input_ax = figure.add_subplot(1, 2, 1, projection='rectilinear')
    latent_ax = figure.add_subplot(1, 2, 2)

    if input_dim == 3 and latent_dim == 2:
        # Fold the flat list of grid nodes back into a square mesh so that it
        # can be drawn as a wireframe.
        _, node_count, _ = Y_history.shape
        side = int(np.sqrt(node_count))
        mesh_shape = (num_epoch, side, side, input_dim)
        Y_history = np.array(Y_history).reshape(mesh_shape)

    # Drawers are looked up by dimensionality; unsupported sizes map to None.
    observable_drawers = (None, None, draw_observable_2D, draw_observable_3D)
    latent_drawers = (None, draw_latent_1D, draw_latent_2D)
    observable_drawer = observable_drawers[input_dim]
    latent_drawer = latent_drawers[latent_dim]

    print("X: {}, Y: {}, Z:{}".format(X.shape, Y_history[0].shape,
                                      Z_history[0].shape))

    frame_args = (observable_drawer, latent_drawer, X, Y_history, Z_history,
                  figure, input_ax, latent_ax, num_epoch)
    animation = FuncAnimation(figure,
                              update_graph,
                              frames=num_epoch,
                              repeat=True,
                              interval=50,
                              fargs=frame_args)
    plt.show()
    if save_gif:
        animation.save("tmp.gif", writer='pillow')
def update_graph(epoch, observable_drawer, latent_drawer, X, Y_history,
                 Z_history, fig, input_ax, latent_ax, num_epoch):
    """Draw the animation frame that belongs to ``epoch``.

    Both axes are cleared, the figure title is set to the epoch number, and
    the two drawers are called with that epoch's reconstruction and latent
    variables. Points are coloured by the first column of ``X``.
    ``num_epoch`` is accepted but not used.
    """
    fig.suptitle(f"epoch: {epoch}")
    for axis in (input_ax, latent_ax):
        axis.cla()
    frame_Y, frame_Z = Y_history[epoch], Z_history[epoch]
    point_colors = X[:, 0]
    observable_drawer(input_ax, X, frame_Y, point_colors)
    latent_drawer(latent_ax, frame_Z, point_colors)
def draw_observable_3D(ax, X, Y, colormap):
    """Scatter 3-D observations and, for mesh-shaped ``Y``, draw a wireframe.

    Args:
        ax: Axes to draw on.
        X: Observations; columns 0, 1 and 2 are plotted.
        Y: Reconstruction. It is drawn only when it has three dimensions, in
            which case its last axis supplies the x, y and z coordinates.
        colormap: Per-point colour values for the scatter plot.
    """
    xs, ys, zs = X[:, 0], X[:, 1], X[:, 2]
    ax.scatter(xs, ys, zs, c=colormap)
    if Y.ndim != 3:
        return
    ax.plot_wireframe(Y[..., 0],
                      Y[..., 1],
                      Y[..., 2],
                      color='black',
                      alpha=0.5)
def draw_observable_2D(ax, X, Y, colormap):
    """Scatter 2-D observations and draw the reconstruction as a black line.

    Args:
        ax: Axes to draw on.
        X: Observations; columns 0 and 1 are plotted.
        Y: Reconstruction; columns 0 and 1 are joined by a line.
        colormap: Per-point colour values for the scatter plot.
    """
    observed_x, observed_y = X[:, 0], X[:, 1]
    ax.scatter(observed_x, observed_y, c=colormap)
    curve_x, curve_y = Y[:, 0], Y[:, 1]
    ax.plot(curve_x, curve_y, c='k')
def draw_latent_2D(ax, Z, colormap):
    """Scatter 2-D latent variables on axes fixed to ``[-5, 5]`` per axis.

    Args:
        ax: Axes to draw on.
        Z: Latent variables; columns 0 and 1 are plotted.
        colormap: Per-point colour values.
    """
    limits = (-5, 5)
    ax.set_xlim(*limits)
    ax.set_ylim(*limits)
    latent_x, latent_y = Z[:, 0], Z[:, 1]
    ax.scatter(latent_x, latent_y, c=colormap)
def draw_latent_1D(ax, Z, colormap):
    """Scatter 1-D latent variables along the line ``y = 0``.

    Args:
        ax: Axes to draw on.
        Z: Latent variables, used as the x coordinates.
        colormap: Per-point colour values.
    """
    baseline = np.zeros(Z.shape)
    ax.scatter(Z, baseline, c=colormap)
    ax.set_ylim(-1, 1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment