Skip to content

Instantly share code, notes, and snippets.

@jxcodetw
Created February 21, 2020 01:03
Show Gist options
  • Save jxcodetw/7af7354382f7965894a134b88b7b5144 to your computer and use it in GitHub Desktop.
Save jxcodetw/7af7354382f7965894a134b88b7b5144 to your computer and use it in GitHub Desktop.
import torch
import torch.optim as optim
import torch.nn.functional as F
from openTSNE import TSNE
from umap.umap_ import fuzzy_simplicial_set, find_ab_params
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import SpectralEmbedding
from scipy.sparse import save_npz, load_npz
import random
# Target embedding geometry; both are handed to umap's find_ab_params().
MIN_DIST: float = 0.1
SPREAD: float = 1.0
# Small constant added inside the logarithms of KLD / CE / MXLK.
EPS: float = 1e-12
# Number of optimisation epochs; also passed to make_graph() for edge pruning.
N_EPOCHS: int = 500
# Divides epochs_per_sample to obtain the negative-sampling schedule.
NEG_RATE: float = 5.0
# Number of positive edges per SGD mini-batch.
BATCH_SIZE: int = 4096
def make_graph(P, n_epochs=-1):
    """Return a pruned COO copy of the weighted graph *P* for edge sampling.

    Duplicate entries are summed, then every edge whose weight is below
    ``max_weight / n_epochs`` is dropped: such an edge would be scheduled
    fewer than once over the whole optimisation (see make_epochs_per_sample).

    Args:
        P: scipy sparse matrix of edge weights. It is not modified.
        n_epochs: number of optimisation epochs. A value <= 0 selects a
            default: 500 when P has at most 10000 rows, otherwise 200.

    Returns:
        A new scipy.sparse COO matrix with the weak edges removed.
    """
    # copy=True matters: tocoo() on a matrix that is already COO returns the
    # same object, and the in-place pruning below would mutate the caller's P.
    graph = P.tocoo(copy=True)
    graph.sum_duplicates()
    if n_epochs <= 0:
        # For smaller datasets we can use more epochs
        n_epochs = 500 if graph.shape[0] <= 10000 else 200
    if graph.nnz == 0:
        # max() of an empty array raises; there is nothing to prune anyway.
        return graph
    graph.data[graph.data < (graph.data.max() / float(n_epochs))] = 0.0
    graph.eliminate_zeros()
    return graph
def make_epochs_per_sample(weights, n_epochs):
    """Compute, for each edge, the number of epochs between two samples of it.

    An edge of weight w is sampled ``n_epochs * w / max(weights)`` times in
    total, so its sampling period is ``n_epochs`` divided by that count.
    Edges with a non-positive count get the sentinel value -1.0.

    Args:
        weights: non-empty 1-D numpy array of edge weights.
        n_epochs: total number of optimisation epochs.

    Returns:
        float64 numpy array with the same length as ``weights``.
    """
    periods = np.full(weights.shape[0], -1.0, dtype=np.float64)
    n_samples = n_epochs * (weights / weights.max())
    positive = n_samples > 0
    periods[positive] = float(n_epochs) / n_samples[positive]
    return periods
def neg_squared_euc_dists(X):
    """Return the negated pairwise squared Euclidean distances between rows of X.

    Uses the expansion ||x_i - x_j||^2 = ||x_i||^2 - 2 x_i.x_j + ||x_j||^2.

    Args:
        X: 2-D tensor whose rows are points, shape (n, dim).

    Returns:
        (n, n) tensor holding ``-||x_i - x_j||^2``; every entry is <= 0.
    """
    sum_X = X.pow(2).sum(dim=1)
    D = (-2 * X @ X.transpose(1, 0) + sum_X).transpose(1, 0) + sum_X
    # The expansion suffers from floating-point cancellation for nearby points
    # far from the origin and can produce small negative "squared distances";
    # clamp them so downstream 1 / (1 + d^2) affinities never exceed 1.
    D = D.clamp(min=0)
    return -D
def w_tsne(Y):
    """Return the unnormalised pairwise affinities ``1 / (1 + ||y_i - y_j||^2)``.

    Args:
        Y: 2-D tensor whose rows are embedding points, shape (n, dim).

    Returns:
        (n, n) tensor of affinities (not normalised to sum to 1).
    """
    neg_sq_dists = neg_squared_euc_dists(Y)
    # 1 - (-d^2) == 1 + d^2, so this is 1 / (1 + d^2).
    return (1. - neg_sq_dists).pow(-1)
def KLD(P, Q):
    """Return the element-wise Kullback-Leibler terms ``P * log(P / Q)``.

    EPS is added to P only, so zero entries of P give a finite logarithm
    (and a zero term). Q is used unchanged.
    """
    ratio = (P + EPS) / Q
    return P * torch.log(ratio)
def CE(V, W):
    """Return the element-wise binary cross-entropy of prediction W against target V.

    Computes ``-V * log(W) - (1 - V) * log(1 - W)``, with EPS added inside
    both logarithms to keep them finite when W is exactly 0 or 1.
    """
    log_w = torch.log(W + EPS)
    log_one_minus_w = torch.log(1 - W + EPS)
    return -V * log_w - (1 - V) * log_one_minus_w
def MXLK(P, w, gamma=7.0):
    """Return the element-wise weighted log-likelihood terms.

    Computes ``P * log(w) + gamma * (1 - P) * log(1 - w)``, with EPS inside
    both logarithms. The result is not negated. It is presumably meant to be
    maximised; it is not called anywhere in this file.

    Args:
        P: target probabilities.
        w: predicted probabilities.
        gamma: weight applied to the (1 - P) term.
    """
    attract = P * torch.log(w + EPS)
    repel = gamma * (1 - P) * torch.log(1 - w + EPS)
    return attract + repel
# --- Setup: device, input data, kernel parameters and the fuzzy graph V ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
print('load data')
# The context manager closes the .npz archive once the array has been read.
with np.load('mnist.npz') as mnist:
    data = mnist['data']
print('estimate a, b')
# Fit the a, b kernel parameters from SPREAD / MIN_DIST.
ua, ub = find_ab_params(SPREAD, MIN_DIST)
# ua, ub = 1.0, 1.0  # uncomment to use a = b = 1 instead of the fitted values
print('a:', ua, 'b:', ub)
print('calc V')
try:
    V_csc = load_npz('V_csc.npz')
except (OSError, KeyError, ValueError):
    # No usable cache: build the graph and cache it. save_npz appends the
    # '.npz' extension itself, matching the name loaded above.
    print('Use new V')
    V_csc = fuzzy_simplicial_set(data, n_neighbors=15,
                                 random_state=np.random.RandomState(42),
                                 metric='euclidean')
    save_npz('V_csc', V_csc)
else:
    print('Use V cache')
# Dense copy of the graph, used for the monitoring loss.
# NOTE(review): this is n x n, so memory grows quadratically with the dataset.
V = torch.Tensor(V_csc.toarray())
# 1 everywhere except 0 on the diagonal, so self-pairs can be excluded.
diag_mask = (1 - torch.eye(V.size(0))).to(device)
# --- Edge-sampling schedule and initial embedding ---
print('make_graph')
graph = make_graph(V_csc, N_EPOCHS)
print('make_epochs_per_sample')
epochs_per_sample = make_epochs_per_sample(graph.data, N_EPOCHS)
INIT_METHOD = 'random'
if INIT_METHOD == 'random':
    print('Random init Y')
    Y_init = np.random.randn(V.shape[0], 2) * 10
elif INIT_METHOD == 'spectral':
    print('Spectral init Y')
    try:
        # Close the cached archive once 'Y' has been read.
        with np.load('umap_Y_init.npz') as cached:
            Y_init = cached['Y']
    except (OSError, KeyError, ValueError):
        # Cache missing, unreadable or lacking 'Y': recompute and re-cache it.
        print('new spectral init')
        model = SpectralEmbedding(n_components=2, n_neighbors=50)
        Y_init = model.fit_transform(data) * 10000
        np.savez_compressed('umap_Y_init', Y=Y_init)
    else:
        print('use cache')
else:
    # Fail fast: every later step needs Y_init, which would otherwise be unset.
    raise ValueError('Unknown init method: {}'.format(INIT_METHOD))
# --- Optimise the embedding Y with UMAP-style edge-sampled SGD ---
print('optimizing...')
Y = (torch.from_numpy(Y_init)).to(device).detach().requires_grad_(True)
V = V.to(device)
optimizer = optim.SGD([Y], lr=1)
epochs_per_negative_sample = epochs_per_sample / NEG_RATE
epoch_of_next_negative_sample = epochs_per_negative_sample.copy()
epoch_of_next_sample = epochs_per_sample.copy()
head = graph.row
tail = graph.col
rnd_max_idx = V.shape[0] - 1
# Repulsive passes per positive mini-batch (previously a hard-coded 5, which
# equals the current NEG_RATE).
n_neg_rounds = int(NEG_RATE)
for epoch in range(N_EPOCHS):
    # Vectorised form of the per-edge scan: every edge whose next sampling
    # epoch has been reached contributes one positive pair this epoch.
    due = epoch_of_next_sample <= epoch
    batch_i = torch.from_numpy(head[due].astype(np.int64)).to(device)
    batch_j = torch.from_numpy(tail[due].astype(np.int64)).to(device)
    epoch_of_next_sample[due] += epochs_per_sample[due]
    # Negative-sample bookkeeping. astype(np.int64) truncates toward zero,
    # like int(). NOTE(review): the count is not used below; every
    # mini-batch gets n_neg_rounds repulsive passes.
    n_neg_samples = (
        (epoch - epoch_of_next_negative_sample[due])
        / epochs_per_negative_sample[due]
    ).astype(np.int64)
    epoch_of_next_negative_sample[due] += (
        n_neg_samples * epochs_per_negative_sample[due]
    )
    for start in range(0, batch_i.shape[0], BATCH_SIZE):
        bi = batch_i[start:start + BATCH_SIZE]
        bj = batch_j[start:start + BATCH_SIZE]
        # Attractive step. With d = ||y_i - y_j||^2, the gradient of
        # log(1 + a d^b) w.r.t. y_i is 2ab d^(b-1) / (1 + a d^b) * (y_i - y_j).
        optimizer.zero_grad()
        ydiff = Y[bi].detach() - Y[bj].detach()
        d = ydiff.pow(2).sum(dim=1, keepdim=True)
        coeff = (2 * ua * ub * d.pow(ub - 1)) / (1 + ua * d.pow(ub))
        # d^(b-1) is infinite at d == 0 when b < 1, which would turn the
        # (zero) gradient into NaN; UMAP skips such coincident pairs.
        coeff = coeff.masked_fill(d <= 0, 0.0)
        grad = coeff * ydiff
        Y[bi].backward(grad)
        Y[bj].backward(-grad)
        torch.nn.utils.clip_grad_value_([Y], 4)
        optimizer.step()
        for _ in range(n_neg_rounds):
            # Repulsive step against uniformly drawn vertices. This is the
            # gradient of -log(1 - 1 / (1 + a d^b)), i.e.
            # -2b / ((0.001 + d)(1 + a d^b)) * (y_i - y_k), with 0.001 for stability.
            neg = torch.randint(0, rnd_max_idx + 1, (bi.shape[0],), device=device)
            optimizer.zero_grad()
            ydiff = Y[bi].detach() - Y[neg].detach()
            d = ydiff.pow(2).sum(dim=1, keepdim=True)
            coeff = (-2 * ub) / ((1e-3 + d) * (1 + ua * d.pow(ub)))
            grad = coeff * ydiff
            Y[bi].backward(grad)
            torch.nn.utils.clip_grad_value_([Y], 4)
            optimizer.step()
    with torch.no_grad():
        w = w_tsne(Y.detach()).clamp(min=0, max=1)
        # Exclude self-pairs: w_ii is ~1, so when V_ii is 0 each diagonal
        # entry would add roughly -log(EPS) to the monitored loss.
        loss = (CE(V, w) * diag_mask).sum()
    # Linearly decaying learning rate.
    for param_group in optimizer.param_groups:
        param_group['lr'] = 1 - epoch / N_EPOCHS
    np.savez_compressed('umap_fast_Y', Y=Y.detach().cpu().numpy())
    print("{:.2f}".format(loss.item()), "{:.3f}".format(1 - epoch / N_EPOCHS), 'Saved tmp Y')
print('Done.')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment