Created
April 26, 2017 15:08
-
-
Save usmcamp0811/416c7e9f6ca0460996475b6759d676be to your computer and use it in GitHub Desktop.
A simple function for plotting the latent space of a VAE. It works for any high-dimensional data, because it uses t-SNE to reduce the dimensionality to a plottable size (2 or 3 dimensions).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.manifold import TSNE | |
from mpl_toolkits.mplot3d import Axes3D | |
%matplotlib qt | |
from IPython import display | |
import matplotlib.cm as cmx | |
import matplotlib.colors as colors | |
def get_cmap(N):
    """Return a function mapping each index in 0, 1, ..., N-1 to a distinct color.

    Colors are sampled evenly from matplotlib's 'hsv' colormap.

    Args:
        N: number of distinct indices (e.g. classes) that need a color.

    Returns:
        A callable ``f(index)`` returning an ``(r, g, b, a)`` tuple of floats.
    """
    # 'hsv' is a cyclic colormap: its values at 0.0 and 1.0 are the same red.
    # Normalizing over [0, N] rather than [0, N-1] keeps index N-1 away from
    # 1.0, so the first and last indices get different colors.
    # It also avoids a zero-width range (vmin == vmax) when N == 1.
    color_norm = colors.Normalize(vmin=0, vmax=N)
    scalar_map = cmx.ScalarMappable(norm=color_norm, cmap='hsv')

    def map_index_to_rgb_color(index):
        return scalar_map.to_rgba(index)

    return map_index_to_rgb_color
def plot_latent_space(x_batch, y_batch, iteration=None, dim=2, n_iter=250):
    """Embed a batch's latent means with t-SNE and scatter-plot them by class.

    The latent values come from ``mu.eval(feed_dict={X: x_batch})``. ``mu``
    and ``X`` are globals not defined in this file; presumably they are a
    TensorFlow tensor and placeholder with a default session active
    (TODO confirm against the notebook that defines them).

    Args:
        x_batch: input batch fed to the ``X`` placeholder.
        y_batch: 2-D array-like of per-sample label vectors, e.g. one-hot.
            The class of each sample is its argmax along axis 1.
        iteration: optional step number appended to the output filename.
        dim: t-SNE output dimensionality. ``3`` draws a 3-D scatter; any
            other value draws a 2-D scatter of the first two components.
        n_iter: t-SNE optimization iterations. scikit-learn rejects values
            below 250.

    Side effects:
        Clears and draws into matplotlib figure 2. Saves it to
        'latent_space.png' or 'latent_space<iteration>.png', then calls
        ``plt.show()``.
    """
    # The file never imports these names at top level; import them here so
    # the function does not depend on notebook globals.
    import numpy as np
    import matplotlib.pyplot as plt

    tsne_params = dict(n_components=dim, random_state=0, perplexity=50,
                       learning_rate=500)
    try:
        # scikit-learn >= 1.5 calls the iteration budget `max_iter`
        # (and later releases removed `n_iter` entirely).
        model = TSNE(max_iter=n_iter, **tsne_params)
    except TypeError:
        # Older scikit-learn releases only accept `n_iter`.
        model = TSNE(n_iter=n_iter, **tsne_params)
    z_mu = model.fit_transform(mu.eval(feed_dict={X: x_batch}))

    labels = np.argmax(y_batch, 1)
    # Size the colormap by the label-vector width, not by the classes present
    # in this batch. Every label index then stays in range, and a class keeps
    # the same color across calls.
    cmap = get_cmap(np.shape(y_batch)[1])

    fig = plt.figure(2, figsize=(8, 8))
    # Figure 2 is reused on every call; clear it so earlier plots don't pile up.
    fig.clf()
    if dim == 3:
        bx = fig.add_subplot(111, projection='3d')
    else:
        bx = fig.add_subplot(111)

    for i in np.unique(labels):
        mask = labels == i
        coords = [z_mu[mask, 0], z_mu[mask, 1]]
        if dim == 3:
            coords.append(z_mu[mask, 2])
        # Use `color=` rather than `c=`. A single RGBA tuple passed as `c`
        # can be mistaken for per-point values to colormap.
        bx.scatter(*coords, color=cmap(i), label=str(i))

    bx.set_xlabel('X Label')
    bx.set_ylabel('Y Label')
    bx.legend()
    bx.set_title('Truth')
    if iteration is None:
        plt.savefig('latent_space.png')
    else:
        plt.savefig('latent_space' + str(iteration) + '.png')
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment