Created
October 15, 2023 14:00
-
-
Save InfProbSciX/a9495aed39ce9aaa6215cccf33f416ea to your computer and use it in GitHub Desktop.
Running ProbDR in Stan
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import numpy as np | |
import pandas as pd | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
from cmdstanpy import CmdStanModel | |
from sklearn.decomposition import PCA | |
from sklearn.neighbors import NearestNeighbors | |
from torchvision.transforms.functional import rotate | |
from sklearn.manifold import TSNE, SpectralEmbedding | |
from scipy.spatial.distance import cdist, squareform | |
from tensorflow.keras.datasets.mnist import load_data | |
from sklearn.manifold._t_sne import _joint_probabilities | |
# Interactive plotting so figures update without blocking the script.
plt.ion()
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
# newer releases no longer accept the old names, so use whichever name this
# installation provides (falling back to the default style if neither exists).
plt.style.use(next(
    (style_name for style_name in ('seaborn-v0_8-pastel', 'seaborn-pastel')
     if style_name in plt.style.available),
    'default',
))
# Fixed seed so the jitter added in norm_data is reproducible.
np.random.seed(42)
############################################################## | |
# Stan model specification | |
# Shared Stan `functions` block: squared_distances() fills the strict lower
# triangle with row-to-row squared distances and then adds the transpose.
with open('functions.stan', 'w') as stan_file:
    stan_file.write("""
functions {
matrix squared_distances(matrix x) {
int n = dims(x)[1];
matrix[n, n] result = rep_matrix(0.0, n, n);
for(i in 2:n) {
for (j in 1:(i - 1)) {
result[i, j] = squared_distance(x[i, ], x[j, ]);
}
}
result += result';
return result;
}
}
""")
# Shared Stan `data` block: problem sizes (n, d, q) and the n-by-n matrix M_hat.
with open('data.stan', 'w') as stan_file:
    stan_file.write("""
data {
int n; // num data
int d; // num data dims
int q; // num latents
matrix[n, n] M_hat; // psd matrix estimate that uses data
}
""")
# Shared Stan `transformed data` block for the Wishart models: rho starts at d
# and becomes n when n >= d; M is M_hat scaled by rho with a small diagonal
# jitter derived from M_hat's own diagonal.
with open('transformed_data_wishart.stan', 'w') as stan_file:
    stan_file.write("""
transformed data {
matrix[n, n] M;
real jitter = 0.0;
int rho = d;
if (n >= d) rho = n;
jitter = trace(diagonal(M_hat) * rep_vector(1.0/(n * 1e6), n)');
M = add_diag(M_hat * rho, jitter);
}
""")
# Shared Stan `transformed data` block for the count-based models: M_hat is
# scaled by rho_i = 1000000 and converted to an integer array M_int.
with open('transformed_data_discrete.stan', 'w') as stan_file:
    stan_file.write("""
transformed data {
array[n, n] int M_int;
int rho_i = 1000000;
M_int = to_int(to_array_2d(M_hat * rho_i));
}
""")
# Shared Stan `parameters` block: the n-by-q latent matrix X and a variance
# parameter sigma_sq constrained to [1e-6, 1].
with open('params.stan', 'w') as stan_file:
    stan_file.write("""
parameters {
matrix[n, q] X; // latents
real<lower=1e-6, upper=1> sigma_sq;
}
""")
# t-SNE-style model: W is 1 / (1 + squared distance) with the diagonal shifted
# by (1e-6 - 1), normalised to sum to one, and used as the multinomial
# probabilities for the integer counts M_int.
with open('model_tsne.stan', 'w') as stan_file:
    stan_file.write("""
#include functions.stan
#include data.stan
#include transformed_data_discrete.stan
#include params.stan
transformed parameters {
matrix[n, n] W = add_diag(1 ./ (1 + squared_distances(X)), 1e-6 - 1);
W = W / sum(W);
}
model {
// flatish prior on X
for(i in 1:q) X[, i] ~ cauchy(0, 20);
to_array_1d(M_int) ~ multinomial(to_vector(W));
}
""")
# UMAP-style model: W is 1 / (1 + 2 * squared distance); each strictly
# lower-triangular entry of M_int gets a binomial likelihood with rho_i trials.
with open('model_umap.stan', 'w') as stan_file:
    stan_file.write("""
#include functions.stan
#include data.stan
#include transformed_data_discrete.stan
#include params.stan
transformed parameters {
matrix[n, n] W = 1 ./ (1 + 2 * squared_distances(X));
}
model {
// flatish prior on X
for(i in 1:q) X[, i] ~ cauchy(0, 20);
for (i in 2:n) M_int[i, 1:(i - 1)] ~ binomial(rho_i, W[i, 1:(i - 1)]);
}
""")
# PCA-style model: W is X * X' plus sigma_sq on the diagonal, and the scaled
# matrix M gets a Wishart likelihood with rho degrees of freedom and scale W.
with open('model_pca.stan', 'w') as stan_file:
    stan_file.write("""
#include functions.stan
#include data.stan
#include transformed_data_wishart.stan
#include params.stan
transformed parameters {
matrix[n, n] W = add_diag(X * X', sigma_sq);
}
model {
// flatish prior on X
for(i in 1:q) X[, i] ~ cauchy(0, 20);
M ~ wishart(rho, W);
}
""")
# Laplacian-eigenmaps-style model: same Wishart likelihood as the PCA model,
# but W is the inverse of (X * X' + sigma_sq on the diagonal).
with open('model_le.stan', 'w') as stan_file:
    stan_file.write("""
#include functions.stan
#include data.stan
#include transformed_data_wishart.stan
#include params.stan
transformed parameters {
matrix[n, n] W = inverse(add_diag(X * X', sigma_sq));
}
model {
// flatish prior on X
for(i in 1:q) X[, i] ~ cauchy(0, 20);
M ~ wishart(rho, W);
}
""")
############################################################## | |
# Utils | |
def get_transforms(Y, algo_type):
    """Build the data-derived matrix for Stan and a reference embedding.

    Args:
        Y: 2-D data array with one row per sample.
        algo_type: One of 'tsne', 'umap', 'pca' or 'le'.

    Returns:
        A pair ``(M, X_ext)``: ``M`` is the matrix computed from ``Y`` for the
        chosen algorithm (as a numpy array) and ``X_ext`` is the embedding
        produced by the corresponding external implementation.

    Raises:
        NotImplementedError: If ``algo_type`` is not one of the names above.
    """
    _, n_features = Y.shape

    if algo_type == 'tsne':
        reference = TSNE(method='exact')
        # Joint probabilities come back in condensed form; expand to square.
        sq_dists = cdist(Y, Y)**2
        condensed = _joint_probabilities(sq_dists, reference.perplexity, False)
        M = squareform(condensed)
        X_ext = reference.fit_transform(Y)
        return np.asarray(M), X_ext

    if algo_type == 'umap':
        from umap import UMAP
        from umap.umap_ import fuzzy_simplicial_set
        graph, _, _ = fuzzy_simplicial_set(Y, 15, 42, 'euclidean')
        M = graph.toarray()
        X_ext = UMAP(a=2, b=1, init='random').fit_transform(Y)
        return np.asarray(M), X_ext

    if algo_type == 'pca':
        M = Y @ Y.T / n_features
        X_ext = PCA(2, svd_solver='full').fit_transform(Y)
        return np.asarray(M), X_ext

    if algo_type == 'le':
        knn = NearestNeighbors(n_neighbors=10, algorithm='ball_tree').fit(Y)
        adjacency = knn.kneighbors_graph(Y).toarray()
        # Symmetrise the neighbour graph and keep it binary.
        adjacency = (adjacency + adjacency.T).clip(0, 1)
        degree = adjacency.sum(axis=0)
        # Degree matrix minus adjacency.
        M = np.diag(degree) - adjacency
        X_ext = SpectralEmbedding(affinity='precomputed').fit_transform(adjacency)
        return np.asarray(M), X_ext

    raise NotImplementedError('Invalid algorithm.')
def plot(X, c, ax=None):
    """Scatter the first two columns of ``X`` coloured by the labels ``c``.

    Args:
        X: 2-D array; column 0 is plotted on x and column 1 on y.
        c: Array of labels; converted to strings and used as the hue.
        ax: Optional matplotlib axes passed through to seaborn.
    """
    labels = c.astype(str)
    frame = pd.DataFrame(dict(x=X[:, 0], y=X[:, 1], hue=labels))
    # Sort rows by label before plotting.
    frame = frame.set_index('hue')
    frame = frame.sort_index()
    frame = frame.reset_index()
    sns.scatterplot(data=frame, x='x', y='y', hue='hue', palette='Spectral', ax=ax)
class TrainingScheme:
    """Run a model's optimiser from several seeds and keep the best fit.

    ``model`` must expose ``optimize(seed=..., algorithm=..., **kwargs)``
    returning an object with an ``optimized_params_dict`` mapping that
    contains the key ``'lp__'``.
    """

    def __init__(self, model):
        self.model = model

    def _fit_one(self, seed, algo, **kwargs):
        """Optimise once; return ``(fit, lp__)`` for that run."""
        result = self.model.optimize(seed=seed, algorithm=algo, **kwargs)
        log_prob = result.optimized_params_dict['lp__']
        return result, log_prob

    def _fit_one_algo_agnos(self, seed, **kwargs):
        """Optimise with LBFGS, retrying with Newton on ``RuntimeError``."""
        try:
            outcome = self._fit_one(seed, 'LBFGS', **kwargs)
        except RuntimeError:
            outcome = self._fit_one(seed, 'Newton', **kwargs)
        return outcome

    def fit(self, n_runs=5, **kwargs):
        """Return the fit with the highest ``lp__`` over the runs.

        The first run uses seed 42; the remaining ``n_runs - 1`` runs use
        seeds ``0, 1, ...``. A later run replaces the current best only when
        its ``lp__`` is strictly greater.
        """
        best_fit, best_lp = self._fit_one_algo_agnos(42, **kwargs)
        for extra_seed in range(n_runs - 1):
            candidate = self._fit_one_algo_agnos(extra_seed, **kwargs)
            if candidate[1] > best_lp:
                best_fit, best_lp = candidate
        return best_fit
def nearPD(mat, tol=1e-4):
    """Raise every eigenvalue of ``mat`` below ``tol`` up to ``tol``.

    Args:
        mat: Square matrix passed to ``np.linalg.eigh``.
        tol: Smallest eigenvalue allowed in the result.

    Returns:
        The matrix rebuilt from the adjusted eigenvalues and the original
        eigenvectors, averaged with its transpose so it is symmetric.
        Prints how many eigenvalues were adjusted.
    """
    evals, evecs = np.linalg.eigh(mat)
    too_small = evals < tol
    near_zero = np.abs(evals) < tol
    print(f'Converting M to PD, fixed {np.count_nonzero(too_small)} evals, of which ' \
          f'{np.count_nonzero(too_small & near_zero)} were near zero.')
    floored = np.where(too_small, tol, evals)
    rebuilt = evecs @ np.diag(floored) @ evecs.T
    # Remove any asymmetry introduced by floating-point round-off.
    return (rebuilt + rebuilt.T)/2
def norm_data(Y):
    """Flatten 28x28 images, rescale, add noise and centre the data.

    Args:
        Y: Array reshapeable to ``(-1, 784)``; values are divided by 255.

    Returns:
        A new array of shape ``(num_images, 784)`` with Gaussian jitter
        (standard deviation 0.05, drawn from ``np.random``) added, then
        centred per column and afterwards per row. ``Y`` is not modified.
    """
    flat = Y.copy().reshape(-1, 28**2)/255
    # some jitter
    noisy = flat + np.random.normal(0, 0.05, size=flat.shape)
    # de-mean the data as we assume zero mean in the gen. models
    column_centred = noisy - noisy.mean(axis=0)
    row_means = column_centred.mean(axis=1)[..., None]
    return column_centred - row_means
# Script entry point: builds a small rotated-digit dataset, fits each Stan
# model to it, and saves a figure comparing the Stan embeddings with the
# external implementations.
# NOTE(review): indentation was reconstructed from de-indented text; the axis
# clean-up and savefig at the end are assumed to run once after the loop over
# algorithms — confirm against the original file.
if __name__ == '__main__':
    ##########################################################
    # Data loading
    # Only the second tuple returned by load_data() is used
    # (presumably the test split — verify against the Keras docs).
    (_, _), (Y, c) = load_data()
    # a small dataset
    # Keep the first 10 images whose label is 0, 1 or 7.
    Y = Y[np.isin(c, [0, 1, 7])][:10]
    c = c[np.isin(c, [0, 1, 7])][:10]
    # Augment each image with 25 rotations at angles evenly spaced in [0, 350].
    Ys = []
    for Yi in Y:
        for angle in np.linspace(0, 350, 25):
            Ys.append(rotate(torch.tensor(Yi)[None, ...], angle))
    Y = torch.cat(Ys, axis=0).numpy()
    Y = norm_data(Y)
    # Repeat each label 25 times so labels line up with the rotated copies.
    c = np.repeat(c, 25)
    n, d = Y.shape
    ##########################################################
    # Data prep
    algos_to_run = ['tsne', 'umap', 'pca', 'le']
    # Row 0 holds the Stan embeddings, row 1 the external ones; one column
    # per algorithm.
    fig, axs = plt.subplots(2, len(algos_to_run))
    for i, algo_type in enumerate(algos_to_run):
        model = CmdStanModel(stan_file=f'model_{algo_type}.stan')
        M, X_ext = get_transforms(Y, algo_type)
        # Only the Wishart-based models get their matrix passed through nearPD.
        if algo_type in ['pca', 'le']:
            M = nearPD(M)
        # Write the Stan data block inputs to data.json (overwritten each iteration).
        pd.Series(dict(
            M_hat=M, n=n, d=Y.shape[1],
            q=(3 if algo_type == 'le' else 2) # as le drops a minor eigenvec
        )).to_json('data.json')
        ##########################################################
        # Inference
        kwargs = dict(data='data.json', show_console=True, iter=int(1e5))
        fit = TrainingScheme(model).fit(n_runs=3, **kwargs)
        X = fit.stan_variable('X')
        # Keep the two latent columns with the largest standard deviation.
        idxs = X.std(axis=0)
        idxs = idxs.argsort()[-2:]
        X = X[:, idxs]
        ##########################################################
        # Results
        axs[0, i].set_title(algo_type)
        plot(X, c, axs[0, i])
        plot(X_ext, c, axs[1, i])
    # Strip ticks and axis labels from every panel.
    for ax in axs.reshape(-1):
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel('')
        ax.set_ylabel('')
    # Replace the legend with an empty one on every column except the last.
    for ax in axs[:, :-1].reshape(-1):
        ax.legend([],[], frameon=False)
    axs[0, 0].set_ylabel('map using stan')
    axs[1, 0].set_ylabel('external implementation')
    plt.savefig('stan.pdf')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment