Last active
April 23, 2017 04:26
-
-
Save kskkwn/3faac8748640744cb7ccf3c4254bc195 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
# ---- Model hyper-parameters (these may be worth tuning) ----
# Number of clusters available to both rows (z1) and columns (z2).
nb_k = 8
# Symmetric concentration added to every cluster's occupancy count
# when computing the prior weight of a cluster.
α = 6
# Beta(a0, b0) hyper-parameters for the per-block probability that X == 1.
a0 = 0.5
b0 = 0.5
# -------------------------------------------------------------
import numpy as np | |
from numpy import exp | |
from scipy.special import loggamma as logΓ | |
from numpy.random import choice | |
def m(z):
    """Return the per-cluster occupancy counts of an assignment matrix.

    ``z`` is expected to be an (n_items, nb_k) one-hot matrix (as built in the
    script section); summing over axis 0 counts the items in each cluster.
    """
    return z.sum(axis=0)


# Dirichlet-style concentration vectors for the row (α1) and column (α2)
# cluster priors. NOTE: both names are bound to the *same* array object.
α1 = α2 = α * np.ones(nb_k)
def onehot(i, nb_k):
    """Return a float vector of length ``nb_k`` with 1 at index ``i`` and 0 elsewhere."""
    encoded = np.zeros(nb_k, dtype=float)
    encoded[i] = 1.0
    return encoded
def conditional_z1ᵢ(X, z1, z2, i, alpha=None, a_prior=None, b_prior=None):
    """Return the collapsed-Gibbs conditional p(z1[i] = k | rest) for every row cluster k.

    For each candidate row cluster k the unnormalised weight is

        (alpha_k + m1_hat_k)
        * prod_l B(a_hat_kl + x_pos_l, b_hat_kl + x_neg_l) / B(a_hat_kl, b_hat_kl)

    with B the Beta function, where the "hat" quantities are counts that
    exclude row i, and x_pos_l / x_neg_l count the ones / zeros of row i
    whose column lies in column cluster l.

    Args:
        X: (N1, N2) binary matrix.
        z1: (N1, K1) one-hot row-cluster assignments.
        z2: (N2, K2) one-hot column-cluster assignments.
        i: index of the row being resampled.
        alpha: length-K1 prior concentration; defaults to module-level ``α1``.
        a_prior, b_prior: Beta hyper-parameters; default to ``a0`` / ``b0``.

    Returns:
        Length-K1 array of probabilities summing to 1.
    """
    if alpha is None:
        alpha = α1
    if a_prior is None:
        a_prior = a0
    if b_prior is None:
        b_prior = b0
    X = np.asarray(X)

    # Ones / zeros of row i falling in each column cluster l: shape (K2,).
    x_pos = X[i] @ z2
    x_neg = (1 - X[i]) @ z2
    # Block counts [k, l] over all rows, minus row i's own contribution.
    # (Matrix products replace the original N1*K*N2*K tensordot outer product.)
    n_pos_hat = z1.T @ X @ z2 - np.outer(z1[i], x_pos)
    n_neg_hat = z1.T @ (1 - X) @ z2 - np.outer(z1[i], x_neg)
    m1_hat = z1.sum(axis=0) - z1[i]
    m2 = z2.sum(axis=0)

    a_hat = a_prior + n_pos_hat
    b_hat = b_prior + n_neg_hat
    # Column-cluster vectors (x_pos, x_neg, m2) vary along axis 1 (l), which is
    # exactly how a length-K2 vector broadcasts against a (K1, K2) array.
    log_lik = (
        logΓ(a_hat + b_hat) - logΓ(a_hat) - logΓ(b_hat)
        + logΓ(a_hat + x_pos) + logΓ(b_hat + x_neg)
        - logΓ(a_hat + b_hat + m2)
    )
    # np.real: older SciPy returns complex loggamma even for positive reals.
    # The prior factor is applied once per candidate k (not once per block),
    # and everything is combined in log space; subtracting the max keeps exp()
    # from under/overflowing to 0 or inf for large blocks.
    log_w = np.log(alpha + m1_hat) + np.real(log_lik).sum(axis=1)
    log_w -= log_w.max()
    w = np.exp(log_w)
    return w / w.sum()


def update_z1ᵢ(X, z1, z2, i, alpha=None, a_prior=None, b_prior=None):
    """Sample a new row cluster for row ``i`` from its collapsed conditional.

    Arguments are as for ``conditional_z1ᵢ``. Returns a one-hot float vector of
    length ``z1.shape[1]``; the caller writes it back into ``z1[i]``.
    """
    p = conditional_z1ᵢ(X, z1, z2, i, alpha, a_prior, b_prior)
    k = choice(len(p), p=p)
    new_row = np.zeros(len(p))
    new_row[k] = 1.0
    return new_row
def conditional_z2ⱼ(X, z1, z2, j, alpha=None, a_prior=None, b_prior=None):
    """Return the collapsed-Gibbs conditional p(z2[j] = l | rest) for every column cluster l.

    For each candidate column cluster l the unnormalised weight is

        (alpha_l + m2_hat_l)
        * prod_k B(a_hat_kl + x_pos_k, b_hat_kl + x_neg_k) / B(a_hat_kl, b_hat_kl)

    with B the Beta function, where the "hat" quantities are counts that
    exclude column j, and x_pos_k / x_neg_k count the ones / zeros of column j
    whose row lies in row cluster k.

    Args:
        X: (N1, N2) binary matrix.
        z1: (N1, K1) one-hot row-cluster assignments.
        z2: (N2, K2) one-hot column-cluster assignments.
        j: index of the column being resampled.
        alpha: length-K2 prior concentration; defaults to module-level ``α2``.
        a_prior, b_prior: Beta hyper-parameters; default to ``a0`` / ``b0``.

    Returns:
        Length-K2 array of probabilities summing to 1.
    """
    if alpha is None:
        alpha = α2
    if a_prior is None:
        a_prior = a0
    if b_prior is None:
        b_prior = b0
    X = np.asarray(X)

    # Ones / zeros of column j falling in each row cluster k: shape (K1,).
    x_pos = X[:, j] @ z1
    x_neg = (1 - X[:, j]) @ z1
    # Block counts [k, l] over all columns, minus column j's own contribution.
    n_pos_hat = z1.T @ X @ z2 - np.outer(x_pos, z2[j])
    n_neg_hat = z1.T @ (1 - X) @ z2 - np.outer(x_neg, z2[j])
    m2_hat = z2.sum(axis=0) - z2[j]
    m1 = z1.sum(axis=0)

    a_hat = a_prior + n_pos_hat
    b_hat = b_prior + n_neg_hat
    # Row-cluster vectors (x_pos, x_neg, m1) are indexed by k = axis 0 of the
    # (K1, K2) block arrays, so they must be broadcast as column vectors;
    # a bare length-K vector would broadcast along axis 1 (l) instead.
    log_lik = (
        logΓ(a_hat + b_hat) - logΓ(a_hat) - logΓ(b_hat)
        + logΓ(a_hat + x_pos[:, None]) + logΓ(b_hat + x_neg[:, None])
        - logΓ(a_hat + b_hat + m1[:, None])
    )
    # Candidate label is the column cluster l, so sum the block terms over the
    # row clusters (axis 0). Apply the prior once per l, combine in log space,
    # and subtract the max before exp() to avoid under/overflow.
    # np.real: older SciPy returns complex loggamma even for positive reals.
    log_w = np.log(alpha + m2_hat) + np.real(log_lik).sum(axis=0)
    log_w -= log_w.max()
    w = np.exp(log_w)
    return w / w.sum()


def update_z2ⱼ(X, z1, z2, j, alpha=None, a_prior=None, b_prior=None):
    """Sample a new column cluster for column ``j`` from its collapsed conditional.

    Arguments are as for ``conditional_z2ⱼ``. Returns a one-hot float vector of
    length ``z2.shape[1]``; the caller writes it back into ``z2[j]``.
    """
    p = conditional_z2ⱼ(X, z1, z2, j, alpha, a_prior, b_prior)
    l = choice(len(p), p=p)
    new_col = np.zeros(len(p))
    new_col[l] = 1.0
    return new_col
if __name__ == '__main__':
    import os
    import pickle

    import pandas as pd
    import tqdm

    nb_sample_steps = 10000
    nb_burnin_steps = 1000
    samples_pkl_file = "./sample_z.pkl"

    data = pd.read_csv("./combinationTable.csv")
    # The first data row is set aside as `uname` and excluded from X.
    # `.values` replaces `DataFrame.get_values()`, which was removed in pandas 1.0.
    # (The former `data.drop(0)` discarded its result and has been removed.)
    uname = data[:1].values[0]
    # Remaining rows form the binary matrix: cells equal to the string "True" -> 1.
    X = (data.values[1:] == "True").astype(int)
    N1, N2 = X.shape

    if not os.path.exists(samples_pkl_file):
        # Fresh chain: every row and column starts in cluster 0, then burn in
        # without recording samples.
        z1 = np.zeros((N1, nb_k))
        z1[:, 0] = 1
        z2 = np.zeros((N2, nb_k))
        z2[:, 0] = 1
        samples_z1 = []
        samples_z2 = []
        for step in tqdm.tqdm(range(nb_burnin_steps)):
            for i in range(N1):
                z1[i] = update_z1ᵢ(X, z1, z2, i)
            for j in range(N2):
                z2[j] = update_z2ⱼ(X, z1, z2, j)
    else:
        # Resume: reload saved samples and restart from the last stored state.
        # NOTE: pickle.load executes arbitrary code from the file; only point
        # this at a checkpoint this script wrote itself.
        with open(samples_pkl_file, "rb") as f:
            samples_z1, samples_z2 = pickle.load(f)
        z1 = np.array([onehot(k, nb_k) for k in samples_z1[-1]])
        z2 = np.array([onehot(k, nb_k) for k in samples_z2[-1]])

    for step in tqdm.tqdm(range(nb_sample_steps)):
        for i in range(N1):
            z1[i] = update_z1ᵢ(X, z1, z2, i)
        for j in range(N2):
            z2[j] = update_z2ⱼ(X, z1, z2, j)
        if step % 10 == 0:
            # Thin the chain: store cluster labels every 10th sweep.
            samples_z1.append(np.argmax(z1, axis=1))
            samples_z2.append(np.argmax(z2, axis=1))
        if step % 100 == 0:
            # Periodic checkpoint so an interrupted run can be resumed.
            with open(samples_pkl_file, "wb") as f:
                pickle.dump([samples_z1, samples_z2], f)

    # Final checkpoint: the periodic one misses samples recorded after the
    # last multiple of 100.
    with open(samples_pkl_file, "wb") as f:
        pickle.dump([samples_z1, samples_z2], f)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment