@InfProbSciX · Created October 15, 2023
Running ProbDR in Stan
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from cmdstanpy import CmdStanModel
from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors
from torchvision.transforms.functional import rotate
from sklearn.manifold import TSNE, SpectralEmbedding
from scipy.spatial.distance import cdist, squareform
from tensorflow.keras.datasets.mnist import load_data
from sklearn.manifold._t_sne import _joint_probabilities
plt.ion(); plt.style.use('seaborn-pastel')
np.random.seed(42)
##############################################################
# Stan model specification
with open('functions.stan', 'w') as f: f.write("""
functions {
    matrix squared_distances(matrix x) {
        int n = dims(x)[1];
        matrix[n, n] result = rep_matrix(0.0, n, n);
        // fill the strict lower triangle, then symmetrise
        for (i in 2:n) {
            for (j in 1:(i - 1)) {
                result[i, j] = squared_distance(x[i, ], x[j, ]);
            }
        }
        result += result';
        return result;
    }
}
""")
with open('data.stan', 'w') as f: f.write("""
data {
    int n;              // num data
    int d;              // num data dims
    int q;              // num latents
    matrix[n, n] M_hat; // psd matrix estimate that uses data
}
""")
with open('transformed_data_wishart.stan', 'w') as f: f.write("""
transformed data {
    matrix[n, n] M;
    real jitter = 0.0;
    int rho = d;
    if (n >= d) rho = n; // degrees of freedom: max(n, d)
    // jitter = mean(diagonal(M_hat)) / 1e6, written as a trace
    jitter = trace(diagonal(M_hat) * rep_vector(1.0/(n * 1e6), n)');
    M = add_diag(M_hat * rho, jitter);
}
""")
with open('transformed_data_discrete.stan', 'w') as f: f.write("""
transformed data {
    array[n, n] int M_int;
    int rho_i = 1000000;
    // turn the normalised affinities into integer counts for the discrete likelihoods
    M_int = to_int(to_array_2d(M_hat * rho_i));
}
""")
with open('params.stan', 'w') as f: f.write("""
parameters {
    matrix[n, q] X; // latents
    real<lower=1e-6, upper=1> sigma_sq;
}
""")
with open('model_tsne.stan', 'w') as f: f.write("""
#include functions.stan
#include data.stan
#include transformed_data_discrete.stan
#include params.stan
transformed parameters {
    // Student-t kernel, with the self-affinities shrunk to 1e-6
    matrix[n, n] W = add_diag(1 ./ (1 + squared_distances(X)), 1e-6 - 1);
    W = W / sum(W);
}
model {
    // flatish prior on X
    for (i in 1:q) X[, i] ~ cauchy(0, 20);
    to_array_1d(M_int) ~ multinomial(to_vector(W));
}
""")
with open('model_umap.stan', 'w') as f: f.write("""
#include functions.stan
#include data.stan
#include transformed_data_discrete.stan
#include params.stan
transformed parameters {
    // low-dimensional affinities with a = 2, b = 1
    matrix[n, n] W = 1 ./ (1 + 2 * squared_distances(X));
}
model {
    // flatish prior on X
    for (i in 1:q) X[, i] ~ cauchy(0, 20);
    // independent binomials on the strict lower-triangular edges
    for (i in 2:n) M_int[i, 1:(i - 1)] ~ binomial(rho_i, W[i, 1:(i - 1)]);
}
""")
with open('model_pca.stan', 'w') as f: f.write("""
#include functions.stan
#include data.stan
#include transformed_data_wishart.stan
#include params.stan
transformed parameters {
    matrix[n, n] W = add_diag(X * X', sigma_sq);
}
model {
    // flatish prior on X
    for (i in 1:q) X[, i] ~ cauchy(0, 20);
    M ~ wishart(rho, W);
}
""")
with open('model_le.stan', 'w') as f: f.write("""
#include functions.stan
#include data.stan
#include transformed_data_wishart.stan
#include params.stan
transformed parameters {
    // the latent gram matrix parameterises a precision, not a covariance
    matrix[n, n] W = inverse(add_diag(X * X', sigma_sq));
}
model {
    // flatish prior on X
    for (i in 1:q) X[, i] ~ cauchy(0, 20);
    M ~ wishart(rho, W);
}
""")
##############################################################
# Utils
def get_transforms(Y, algo_type):
    # returns the algorithm-specific statistic M_hat fed to Stan, plus a
    # reference embedding X_ext from the external implementation
    n, d = Y.shape
    if algo_type == 'tsne':
        ext_model = TSNE(method='exact')
        M = squareform(_joint_probabilities(cdist(Y, Y)**2, ext_model.perplexity, False))
        X_ext = ext_model.fit_transform(Y)
    elif algo_type == 'umap':
        from umap import UMAP
        from umap.umap_ import fuzzy_simplicial_set
        M, _, _ = fuzzy_simplicial_set(Y, 15, 42, 'euclidean')
        M = M.toarray()
        X_ext = UMAP(a=2, b=1, init='random').fit_transform(Y)
    elif algo_type == 'pca':
        M = Y @ Y.T / d
        X_ext = PCA(2, svd_solver='full').fit_transform(Y)
    elif algo_type == 'le':
        nbrs = NearestNeighbors(n_neighbors=10, algorithm='ball_tree').fit(Y)
        adjacency = nbrs.kneighbors_graph(Y).toarray()
        adjacency = (adjacency + adjacency.T).clip(0, 1)
        M = np.diag(adjacency.sum(axis=0)) - adjacency  # graph Laplacian
        X_ext = SpectralEmbedding(affinity='precomputed').fit_transform(adjacency)
    else:
        raise NotImplementedError('Invalid algorithm.')
    return np.asarray(M), X_ext
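# usage sketch (shapes are illustrative):
#   M, X_ext = get_transforms(Y, 'tsne')  # M: (n, n), X_ext: (n, 2)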
def plot(X, c, ax=None):
    plot_df = pd.DataFrame(dict(x=X[:, 0], y=X[:, 1], hue=c.astype(str)))
    plot_df = plot_df.set_index('hue').sort_index().reset_index()
    sns.scatterplot(data=plot_df, x='x', y='y', hue='hue', palette='Spectral', ax=ax)
class TrainingScheme:
    def __init__(self, model):
        self.model = model
    def _fit_one(self, seed, algo, **kwargs):
        fit = self.model.optimize(seed=seed, algorithm=algo, **kwargs)
        return fit, fit.optimized_params_dict['lp__']
    def _fit_one_algo_agnos(self, seed, **kwargs):
        # fall back to Newton if L-BFGS errors out
        try:
            return self._fit_one(seed, 'LBFGS', **kwargs)
        except RuntimeError:
            return self._fit_one(seed, 'Newton', **kwargs)
    def fit(self, n_runs=5, **kwargs):
        # multiple restarts; keep the fit with the highest log density
        fit, lp = self._fit_one_algo_agnos(42, **kwargs)
        for seed in range(n_runs - 1):
            fit_r, lp_r = self._fit_one_algo_agnos(seed, **kwargs)
            if lp_r > lp:
                fit, lp = fit_r, lp_r
        return fit
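# usage sketch, mirroring the main loop below:
#   fit = TrainingScheme(CmdStanModel(stan_file='model_pca.stan')).fit(
#       n_runs=3, data='data.json', iter=int(1e5))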
def nearPD(mat, tol=1e-4):
    # clip eigenvalues from below so the Wishart scatter is positive definite
    l, u = np.linalg.eigh(mat)
    ls_fixed = l < tol
    print(f'Converting M to PD, fixed {sum(ls_fixed)} evals, of which '
          f'{sum(ls_fixed[abs(l) < tol])} were near zero.')
    l[l < tol] = tol
    mat = u @ np.diag(l) @ u.T
    mat = (mat + mat.T) / 2  # re-symmetrise against round-off
    return mat
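# quick check with a deliberately indefinite input (values illustrative):
#   nearPD(np.array([[1.0, 2.0], [2.0, 1.0]]))  # eigenvalues 3 and -1; the
#   negative one is clipped up to tol and the matrix re-symmetrised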
def norm_data(Y):
    Y = Y.copy().reshape(-1, 28**2) / 255
    # some jitter
    Y += np.random.normal(0, 0.05, size=Y.shape)
    # de-mean the data as we assume zero mean in the gen. models
    Y -= Y.mean(axis=0)
    Y -= Y.mean(axis=1)[..., None]
    return Y
if __name__ == '__main__':
    ##########################################################
    # Data loading
    (_, _), (Y, c) = load_data()
    # a small dataset
    Y = Y[np.isin(c, [0, 1, 7])][:10]
    c = c[np.isin(c, [0, 1, 7])][:10]
    Ys = []
    for Yi in Y:
        for angle in np.linspace(0, 350, 25):
            Ys.append(rotate(torch.tensor(Yi)[None, ...], angle))
    Y = torch.cat(Ys, axis=0).numpy()
    Y = norm_data(Y)
    c = np.repeat(c, 25)
    n, d = Y.shape
    ##########################################################
    # Data prep
    algos_to_run = ['tsne', 'umap', 'pca', 'le']
    fig, axs = plt.subplots(2, len(algos_to_run))
    for i, algo_type in enumerate(algos_to_run):
        model = CmdStanModel(stan_file=f'model_{algo_type}.stan')
        M, X_ext = get_transforms(Y, algo_type)
        if algo_type in ['pca', 'le']:
            M = nearPD(M)
        pd.Series(dict(
            M_hat=M, n=n, d=Y.shape[1],
            q=(3 if algo_type == 'le' else 2)  # as le drops a minor eigenvec
        )).to_json('data.json')
        ##########################################################
        # Inference
        kwargs = dict(data='data.json', show_console=True, iter=int(1e5))
        fit = TrainingScheme(model).fit(n_runs=3, **kwargs)
        X = fit.stan_variable('X')
        # keep the two highest-variance latent columns (le fits q = 3)
        idxs = X.std(axis=0).argsort()[-2:]
        X = X[:, idxs]
        ##########################################################
        # Results
        axs[0, i].set_title(algo_type)
        plot(X, c, axs[0, i])
        plot(X_ext, c, axs[1, i])
    for ax in axs.reshape(-1):
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel('')
        ax.set_ylabel('')
    for ax in axs[:, :-1].reshape(-1):
        ax.legend([], [], frameon=False)
    axs[0, 0].set_ylabel('map using stan')
    axs[1, 0].set_ylabel('external implementation')
    plt.savefig('stan.pdf')