Skip to content

Instantly share code, notes, and snippets.

@keimina
Last active July 10, 2021 04:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save keimina/e73bf88714b20bc793afd1755c481293 to your computer and use it in GitHub Desktop.
Save keimina/e73bf88714b20bc793afd1755c481293 to your computer and use it in GitHub Desktop.
import pandas as pd
import matplotlib
matplotlib.use("Qt5Agg")
import matplotlib.pyplot as plt
import io
import numpy as np
import torch
from itertools import product
# random のシードを設定する
np.random.seed(0)
def compute_pdistance_normalize_sym(values):
d_data = []
for (i, u), (j, v) in product(enumerate(values), enumerate(values)):
# Step 1. それぞれの2点間の距離を求め、それを行列で表現する
r = u - v
d = r.matmul(r)
d = torch.sqrt(d)
# Step 2. Normal Distribution 上での値を求める
pdistance = torch.exp(-d**2/2)/np.sqrt(2*np.pi)
# Step 3. 自分自身との距離は0にする(行列の対角を0にする)
if i==j:
pdistance = 0.0
d_data.append(pdistance)
# Step 4. 行列にする
pdistance = np.array(d_data, dtype=object) # 行列を torch.tensor ではなく np.array で dtype=object で作成する。torch.tensor で作成するとtensorのコピーが作成され、元のtensorの参照でなくなる現象が発生するため。
pdistance = pdistance.reshape(len(values), len(values))
# Step 5. 行の総和が1になるように正規化する
pdistance_normalize = pdistance / pdistance.sum(1, keepdims=True)
# Step 6. 転置行列を足して平均をとる
pdistance_normalize_sym = (pdistance_normalize + pdistance_normalize.T)/2
return pdistance_normalize_sym
# 今回扱うデータを作成する
file = io.StringIO("""
x,y,color
3,1,r
2,2,r
-1,3,g
-2,2,g
-3,-1,b
-2,-2,b
""".strip())
# データ読み込み
df = pd.read_csv(file)
# 教師データを作成する
values = torch.tensor(df[["x", "y"]].values.astype(np.float32))
t = compute_pdistance_normalize_sym(values)
##########################################################################
# 学習フェーズ
##########################################################################
# Step 1. 学習するパラメータの初期化
o_param = np.random.choice(range(len(df)), len(df), False)
o_param = torch.tensor(o_param.astype(np.float32), requires_grad=True)
o_param = o_param.reshape(-1,1)
o_param.retain_grad() # この1行が超重要、ないとエラーになる
print("学習前:", o_param)
# Step 3. 学習
lr = 0.03
for epoch in range(100):
o_param.grad = None
# 学習するパラメータで2点間の距離を求める
o = compute_pdistance_normalize_sym(o_param)
# Step 3. 教師データと、学習パラメータのKLダイバージェンスをコスト関数とする
# KLダイバージェンス: "https://ja.wikipedia.org/wiki/カルバック・ライブラー情報量" の式を使用
losses = []
for i in range(o.shape[0]):
for j in range(o.shape[1]):
if i==j:
# 対角成分は 0 のため log(0) の計算を行うのを避ける
continue
else:
p = o[i,j] # 厳密には 6で割る必要があると思うが、勾配を求めるのが目的のため省略
q = t[i,j]
losses.append(-p*torch.log(q) + p*torch.log(p))
# 右の式を使用 → "https://ja.wikipedia.org/wiki/カルバック・ライブラー情報量#交差エントロピーとの関係"
loss = sum(losses)
print("loss:", loss)
loss.backward()
o_param.data = o_param.data - lr*o_param.grad.data
print("学習後", o_param)
# 可視化する
df2 = pd.DataFrame(
torch.cat([o_param, torch.zeros_like(o_param)], axis=1).detach().numpy(),
columns=["x", "y"])
df2["color"] = list("rrggbb")
fig, axes = plt.subplots(4, 1, gridspec_kw={'height_ratios': [6, 1, 1, 1]}) # https://stackoverflow.com/questions/10388462/matplotlib-different-size-subplots
df.plot.scatter("x", "y", c="color", ax=axes[0])
df2.plot.scatter("x", "y", c="color", ax=axes[1])
axes[0].set_title("Input")
axes[1].set_title("My Output")
# 検証用 1
from sklearn.manifold import TSNE
X_embedded = TSNE(n_components=1).fit_transform(df.loc[:,["x", "y"]])
df3 = pd.DataFrame(
np.concatenate([X_embedded, np.zeros_like(X_embedded)], axis=1),
columns=["x", "y"])
df3["color"] = list("rrggbb")
df3.plot.scatter("x", "y", c="color", ax=axes[2])
axes[2].set_title("Sklearn TSNE Output")
# 検証用 2
from sklearn.decomposition import PCA
X_embedded = PCA(n_components=1).fit_transform(df.loc[:,["x", "y"]])
df4 = pd.DataFrame(
np.concatenate([X_embedded, np.zeros_like(X_embedded)], axis=1),
columns=["x", "y"])
df4["color"] = list("rrggbb")
df4.plot.scatter("x", "y", c="color", ax=axes[3])
axes[3].set_title("Sklearn PCA Output")
#####################
# matplotlib 職人用 #
#####################
# yaxis を見えなくする
axes[1].yaxis.set_visible(False)
axes[2].yaxis.set_visible(False)
axes[3].yaxis.set_visible(False)
# xaxis のラベルを消す
axes[1].set_xlabel("")
axes[2].set_xlabel("")
axes[3].set_xlabel("")
# レイアウトをつめる
fig.tight_layout()
fig.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment