# Gist by @jxcodetw, created February 20, 2020
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from openTSNE import TSNE
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import SpectralEmbedding
from scipy.sparse import save_npz, load_npz
import random
from functools import partial
from torch.nn.utils import spectral_norm
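
# Training hyperparameters and run configuration.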
EPS = 1e-12                  # numerical floor inside logs
N_EPOCHS = 500               # optimization epochs
NEG_RATE = 5.0               # negative sampling rate relative to positive edges
BATCH_SIZE = 4096            # positive pairs per mini-batch
FORCE_RETRY = False          # recompute the affinity matrix even if a cache exists
DATASET_PATH = 'mnist.npz'
EVAL_ON_CPU = False          # evaluate the full-embedding objective on CPU instead of GPU
WHOLE_NET_GRAD_CLIP = True   # clip gradients by value (at 4) before each optimizer step
def get_activation(act):
    if act == 'lrelu':
        return nn.LeakyReLU(0.2, inplace=True)
    elif act == 'relu':
        return nn.ReLU(inplace=True)
    raise Exception('unsupported activation function')
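
# Fully connected encoder: maps 784-d flattened MNIST digits to a 2-d embedding.
# Architecture: dim -> 256 -> 128, then num_layers hidden 128 -> 128 blocks, then 128 -> 2.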
class FCEncoder(nn.Module):
    def __init__(self, dim, num_layers=3, act='lrelu'):
        super(FCEncoder, self).__init__()
        self.dim = dim
        self.num_layers = num_layers
        self.act = partial(get_activation, act=act)
        hidden_dim = 128
        layers = [
            nn.Linear(dim, hidden_dim * 2),
            self.act(),
            nn.Linear(hidden_dim * 2, hidden_dim),
            self.act(),
        ]
        # Build the hidden blocks in a loop so each block gets its own weights;
        # multiplying the list by num_layers would reuse a single Linear instance.
        for _ in range(num_layers):
            layers += [
                nn.Linear(hidden_dim, hidden_dim),
                self.act(),
            ]
        layers += [
            nn.Linear(hidden_dim, 2),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, X):
        return self.net(X)
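
# Prune the affinity graph: edges too weak to be sampled at least once in n_epochs
# are dropped (the epoch-based edge-sampling heuristic used by UMAP/LargeVis-style
# optimizers).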
def make_graph(P, n_epochs=-1):
    graph = P.tocoo()
    graph.sum_duplicates()
    n_vertices = graph.shape[1]
    if n_epochs <= 0:
        # For smaller datasets we can use more epochs
        if graph.shape[0] <= 10000:
            n_epochs = 500
        else:
            n_epochs = 200
    graph.data[graph.data < (graph.data.max() / float(n_epochs))] = 0.0
    graph.eliminate_zeros()
    return graph
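
# For each retained edge, the number of epochs between consecutive samples of it;
# the strongest edge is sampled every epoch, weaker edges proportionally less often.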
def make_epochs_per_sample(weights, n_epochs):
    result = -1.0 * np.ones(weights.shape[0], dtype=np.float64)
    n_samples = n_epochs * (weights / weights.max())
    result[n_samples > 0] = float(n_epochs) / n_samples[n_samples > 0]
    return result
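
# Pairwise negative squared Euclidean distances, computed via the
# ||x||^2 - 2<x, y> + ||y||^2 expansion.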
def neg_squared_euc_dists(X):
    sum_X = X.pow(2).sum(dim=1)
    D = (-2 * X @ X.transpose(1, 0) + sum_X).transpose(1, 0) + sum_X
    return -D
def w_tsne(Y):
    distances = neg_squared_euc_dists(Y)
    inv_distances = (1. - distances).pow(-1)  # 1 / (1 + d^2)
    return inv_distances
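
# Candidate objectives. Only MXLK (a LargeVis-style log-likelihood with
# negative-sample weight gamma) is used below, to evaluate the full embedding;
# KLD and CE are defined but unused.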
def KLD(P, Q):
    return P * torch.log((P + EPS) / Q)

def CE(P, Q):
    return - P * torch.log(Q + EPS) - (1 - P) * torch.log(1 - Q + EPS)

def MXLK(P, w, gamma=7.0):
    return P * torch.log(w + EPS) + gamma * (1 - P) * torch.log(1 - w + EPS)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
print('load data')
mnist = np.load(DATASET_PATH)
data = mnist['data'].astype('float')
print('calc P')
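# High-dimensional affinities P are computed once with openTSNE (perplexity 30)
# and cached on disk as P_csc.npz; set FORCE_RETRY to recompute them.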
try:
    if FORCE_RETRY:
        raise Exception()
    P_csc = load_npz('P_csc.npz')
    print('Use P cache')
except:
    print('Use new P')
    pre_embedding = TSNE(perplexity=30).prepare_initial(data)
    P_csc = pre_embedding.affinities.P
    save_npz('P_csc', pre_embedding.affinities.P)
print('convert P to torch.Tensor')
P = torch.Tensor(P_csc.toarray())
diag_mask = (1 - torch.eye(P.size(0))).to(device)
print('make_graph')
graph = make_graph(P_csc, N_EPOCHS)
print('make_epochs_per_sample')
epochs_per_sample = make_epochs_per_sample(graph.data, N_EPOCHS)
print('Constructing NN')
encoder = FCEncoder(784, num_layers=10)
encoder = encoder.to(device)
encoder = encoder.float()
print('optimizing...')
P = P.to(device)
# Y = (torch.from_numpy(Y_init)).to(device).detach().requires_grad_(True)
# optimizer = optim.SGD([Y], lr=1)
X = torch.from_numpy(data)
X = X.to(device)
X = X.float()
init_lr = 1e-3
optimizer = optim.SGD(encoder.parameters(), lr=init_lr, weight_decay=0)
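# Edge-sampling schedule (UMAP-style): positive edge i is drawn every
# epochs_per_sample[i] epochs. The negative-sample schedule is tracked as well,
# although the batched loop below draws a fixed 5 uniform negatives per positive.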
epochs_per_negative_sample = epochs_per_sample / NEG_RATE
epoch_of_next_negative_sample = epochs_per_negative_sample.copy()
epoch_of_next_sample = epochs_per_sample.copy()
head = graph.row
tail = graph.col
rnd_max_idx = P.shape[0] - 1
init_gamma = 7
gamma = init_gamma
losses = []
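# Main optimization loop. Each epoch:
#   1. collect the positive edges whose sampling schedule is due,
#   2. attract each positive pair in mini-batches of BATCH_SIZE,
#   3. repel each anchor from 5 uniformly sampled negatives (target side detached),
#   4. evaluate the MXLK objective on the full embedding, decay the learning rate
#      linearly, and save the current embedding to largevis_fast_nn_Y.npz.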
for epoch in range(N_EPOCHS):
    batch_i = []
    batch_j = []
    batch_neg_i = []
    batch_neg_j = []
    for i in range(epochs_per_sample.shape[0]):
        if epoch_of_next_sample[i] <= epoch:
            i_idx, j_idx = head[i], tail[i]
            batch_i.append(i_idx)
            batch_j.append(j_idx)
            epoch_of_next_sample[i] += epochs_per_sample[i]
            n_neg_samples = int(
                (epoch - epoch_of_next_negative_sample[i])
                / epochs_per_negative_sample[i]
            )
            epoch_of_next_negative_sample[i] += (
                n_neg_samples * epochs_per_negative_sample[i]
            )
    for i in range(0, len(batch_i), BATCH_SIZE):
        bi = batch_i[i:i+BATCH_SIZE]
        bj = batch_j[i:i+BATCH_SIZE]
        # Attractive step: pull positive pairs together.
        optimizer.zero_grad()
        Y_bi = encoder(X[bi])
        Y_bj = encoder(X[bj])
        d = (Y_bi - Y_bj).pow(2).sum(dim=1)
        w = (1 / (1 + d)).clamp(min=0, max=1)
        loss = -(torch.log(w + EPS))
        loss = loss.sum()
        loss.backward()
        if WHOLE_NET_GRAD_CLIP:
            torch.nn.utils.clip_grad_value_(encoder.parameters(), 4)
        optimizer.step()
        # Repulsive steps: push each anchor away from random negatives,
        # keeping the negative side fixed (no gradient through Y_bj).
        for p in range(5):
            bj = [random.randint(0, rnd_max_idx) for _ in range(len(bi))]
            optimizer.zero_grad()
            Y_bi = encoder(X[bi])
            with torch.no_grad():
                Y_bj = encoder(X[bj]).detach()
            d = (Y_bi - Y_bj).pow(2).sum(dim=1)
            w = (1 / (1 + d)).clamp(min=0, max=1)
            loss = -(gamma * torch.log(1 - w + EPS))
            loss = loss.sum()
            loss.backward()
            if WHOLE_NET_GRAD_CLIP:
                torch.nn.utils.clip_grad_value_(encoder.parameters(), 4)
            optimizer.step()
    # Evaluate the MXLK objective on the full embedding (optionally on CPU).
    with torch.no_grad():
        if EVAL_ON_CPU:
            encoder = encoder.to('cpu')
            Y = encoder(X.to('cpu'))
            w = w_tsne(Y).clamp(min=0, max=1)
            encoder = encoder.to(device)
            loss = MXLK(P.to('cpu'), w).sum()
            losses.append(loss.item())
        else:
            Y = encoder(X)
            w = w_tsne(Y).clamp(min=0, max=1)
            loss = MXLK(P, w).sum()
            losses.append(loss.item())
    # Linearly decay the learning rate over the run.
    for param_group in optimizer.param_groups:
        param_group['lr'] = (1 - epoch / N_EPOCHS) * init_lr
    # gamma = (epoch / N_EPOCHS) * init_gamma
    np.savez_compressed('largevis_fast_nn_Y', Y=Y.detach().cpu().numpy())
    print("{:.2f}".format(loss.item()), "{:.3f}".format(1 - epoch / N_EPOCHS), 'Saved tmp Y')
    # break
np.savez_compressed('largevis_fast_nn_loss', losses=losses)
print('Done.')
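
# Example (not part of the original gist): one way to visualize the saved embedding.
# This assumes mnist.npz also stores digit labels under a 'target' key; adjust the
# key to match your copy of the dataset before uncommenting.
# Y_saved = np.load('largevis_fast_nn_Y.npz')['Y']
# labels = mnist['target'].astype(int)
# plt.scatter(Y_saved[:, 0], Y_saved[:, 1], c=labels, s=1, cmap='tab10')
# plt.savefig('largevis_fast_nn_Y.png', dpi=200)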