Skip to content

Instantly share code, notes, and snippets.

@InfProbSciX
Created February 14, 2024 12:05
Show Gist options
  • Save InfProbSciX/b8ebd9f5827849c83ad129d3f8bc0a33 to your computer and use it in GitHub Desktop.
Save InfProbSciX/b8ebd9f5827849c83ad129d3f8bc0a33 to your computer and use it in GitHub Desktop.
Random projections distance preservation
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.random_projection import GaussianRandomProjection
from sklearn.neighbors import NearestNeighbors
from scipy.stats import spearmanr
from tqdm import tqdm
# Seed NumPy's global RNG so that the random subset drawn with np.random.choice
# below is reproducible from run to run.
np.random.seed(seed=42)
# Function to compute distances to the nearest neighbors
def compute_nearest_neighbor_distances(X, n_neighbors=15):
    """Return the distances from every row of X to its nearest neighbors.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Points to index and to query. Each point is excluded from its own
        neighbor list.
    n_neighbors : int, default=15
        Number of neighbors (not counting the point itself) per point.

    Returns
    -------
    numpy.ndarray of shape (n_samples * n_neighbors,)
        Flattened distances in row-major order: the n_neighbors distances of
        point 0, then those of point 1, and so on.
    """
    nn = NearestNeighbors(n_neighbors=n_neighbors).fit(X)
    # Calling kneighbors() without a query set makes scikit-learn return the
    # neighbors of every indexed point with the point itself removed (by
    # index). This is one batched query instead of a Python-level
    # kneighbors call per row followed by dropping column 0.
    dists, _ = nn.kneighbors()
    return dists.ravel()
# Load MNIST as a plain NumPy array so the row indexing below does not depend
# on whether fetch_openml returns a DataFrame (its `as_frame` default has
# changed across scikit-learn versions).
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
X = mnist.data  # unscaled pixel values; only the subset below is scaled
# Draw 5000 distinct rows (reproducible via np.random.seed) and scale just that
# subset by 1/255, rather than materialising a scaled copy of the full dataset.
indices = np.random.choice(len(X), 5000, replace=False)
X_subset = X[indices] / 255.0
# Reference: nearest-neighbor distances in the original feature space.
dist_high_nn = compute_nearest_neighbor_distances(X_subset)
dimensions = [2, 3, 10, 30, 100, 250, 500, 750]
spearman_correlations_nn = []
for dim in tqdm(dimensions):
    # Project the data into the lower-dimensional space
    transformer = GaussianRandomProjection(n_components=dim, random_state=42)
    X_projected = transformer.fit_transform(X_subset)
    # Compute distances to the nearest neighbors in the projected space
    dist_low_nn = compute_nearest_neighbor_distances(X_projected)
    # NOTE: entries are paired as (point i, k-th nearest neighbor) in each
    # space; the k-th neighbor itself may differ between the two spaces, so
    # this compares neighbor-distance profiles rather than the same pairs.
    correlation, _ = spearmanr(dist_high_nn, dist_low_nn)
    spearman_correlations_nn.append(correlation)
# Plot the Spearman correlation against projection dimension on a log-scaled
# x axis, placing one tick (labelled with its value) at each tried dimension.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(dimensions, spearman_correlations_nn, marker='o', color='green')
ax.set_xscale('log')
ax.set_xticks(dimensions)
ax.set_xticklabels(dimensions)
ax.set_xlabel('Random Projection Dimension')
ax.set_ylabel('Spearman Correlation')
ax.set_title('MNIST NN distance correlation - high dim vs random proj.')
ax.grid(True)
plt.show()
@InfProbSciX
Copy link
Author

output

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment