Created
February 14, 2024 12:05
-
-
Save InfProbSciX/b8ebd9f5827849c83ad129d3f8bc0a33 to your computer and use it in GitHub Desktop.
Random projections distance preservation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import spearmanr
from sklearn.datasets import fetch_openml
from sklearn.neighbors import NearestNeighbors
from sklearn.random_projection import GaussianRandomProjection
from tqdm import tqdm
# Seed NumPy's legacy global RNG so that module-level draws made through
# np.random (e.g. np.random.choice) are reproducible from run to run.
_SEED = 42
np.random.seed(seed=_SEED)
def compute_nearest_neighbor_distances(X, n_neighbors=15):
    """Return the distances from every row of ``X`` to its nearest neighbours.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Points to index and query. Must contain more than ``n_neighbors`` rows
        so that each point has ``n_neighbors`` neighbours besides itself.
    n_neighbors : int, default=15
        Number of neighbours per point, not counting the point itself.

    Returns
    -------
    numpy.ndarray of shape (n_samples * n_neighbors,)
        Distances flattened row by row: the first ``n_neighbors`` entries are
        the ascending neighbour distances of row 0, the next ``n_neighbors``
        those of row 1, and so on.
    """
    nn = NearestNeighbors(n_neighbors=n_neighbors)
    nn.fit(X)
    # A single batched query replaces one kneighbors() call per row. Calling
    # kneighbors() without X queries the fitted points themselves and drops
    # each point's own entry by index, rather than assuming the point is
    # always returned first (which ties at distance ~0 can break).
    dists, _ = nn.kneighbors(n_neighbors=n_neighbors)
    return dists.ravel()
# Load MNIST as a plain NumPy array. Passing as_frame=False pins the return
# type: depending on the scikit-learn version, the default may yield either a
# NumPy array or a pandas DataFrame, so DataFrame-only indexing (.iloc) is not
# portable.
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
X = mnist.data / 255.0  # presumably maps 0-255 pixel intensities to [0, 1]

# Subsample 5000 distinct rows. np.random.choice draws from NumPy's global RNG,
# so the subset is reproducible only if that RNG was seeded beforehand.
indices = np.random.choice(len(X), 5000, replace=False)
X_subset = X[indices]

# Reference: nearest-neighbour distances in the original pixel space.
dist_high_nn = compute_nearest_neighbor_distances(X_subset)

# NOTE(review): each space is searched independently, so this correlates the
# k-th NN distance in pixel space with the k-th NN distance in projected space;
# the neighbour sets can differ between the two. It is not a check that the
# distances of fixed pairs are preserved. Confirm this is the intended metric.
dimensions = [2, 3, 10, 30, 100, 250, 500, 750]
spearman_correlations_nn = []
for dim in tqdm(dimensions):
    # Project the data into `dim` dimensions; the fixed random_state makes
    # each projection matrix reproducible.
    transformer = GaussianRandomProjection(n_components=dim, random_state=42)
    X_projected = transformer.fit_transform(X_subset)
    # Nearest-neighbour distances in the projected space.
    dist_low_nn = compute_nearest_neighbor_distances(X_projected)
    # Spearman is rank-based, so a global rescaling of distances by the
    # projection does not affect the score.
    correlation, _ = spearmanr(dist_high_nn, dist_low_nn)
    spearman_correlations_nn.append(correlation)

plt.figure(figsize=(10, 6))
plt.plot(dimensions, spearman_correlations_nn, marker='o', color='green')
plt.xscale('log')
plt.xticks(dimensions, labels=dimensions)
plt.xlabel('Random Projection Dimension')
plt.ylabel('Spearman Correlation')
plt.title('MNIST NN distance correlation - high dim vs random proj.')
plt.grid(True)
plt.show()
Author
InfProbSciX
commented
Feb 14, 2024
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment