Created
August 9, 2018 16:36
-
-
Save SchattenGenie/28204a1135c3b7bca06162b7b2adf073 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt
import numpy as np
# Imported for its side effect: before Matplotlib 3.2 this registers the '3d'
# projection requested via ``projection='3d'``.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
from mpl_toolkits.mplot3d.art3d import Line3DCollection
def _new_3d_axes(azim, elev):
    """Open a new 12x12-inch figure holding one 3D axes at the given view angles."""
    fig = plt.figure(figsize=(12, 12))
    # ``Figure.gca(projection=...)`` stopped accepting keyword arguments in
    # Matplotlib 3.6; ``add_subplot`` is the supported way to get a 3D axes.
    ax = fig.add_subplot(111, projection='3d')
    ax.view_init(azim=azim, elev=elev)
    return ax


def _label_and_limit(ax, limits):
    """Apply the axis titles and fixed data limits shared by all three panels."""
    # NOTE(review): data column 0 is titled "z" and column 2 "x"; presumably
    # X stores coordinates in (z, y, x) order -- confirm against the producer.
    ax.set_xlabel("z")
    ax.set_ylabel("y")
    ax.set_zlabel("x")
    (x_min, x_max), (y_min, y_max), (z_min, z_max) = limits
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_zlim(z_min, z_max)


def plot_3d_with_edges(X, l, edges, azim=-84, elev=10):
    """Show three 3D views of a point cloud: raw points, labelled points, edges.

    Each view is drawn in its own figure and displayed with ``plt.show()``.
    All three share the same viewing angle, axis titles and data limits
    (the min/max of the first three columns of ``X``).

    Parameters
    ----------
    X : array_like, shape (n_points, >= 3)
        Point coordinates. Columns 0, 1 and 2 are plotted on the x, y and z
        axes; any further columns are ignored.
    l : array_like, shape (n_points,) or (n_points, n_classes)
        Per-point labels. A 2-D input is reduced to labels with ``argmax``
        over axis 1. Points are coloured by ``log(label + 1)``, so a label of
        -1 maps to ``-inf`` and labels below -1 to ``nan``.
    edges : array_like of int, shape (n_edges, >= 2)
        Columns 0 and 1 are row indices into ``X`` of each edge's two end
        points; any further columns are ignored. May be empty.
    azim, elev : float
        Azimuth and elevation passed to ``view_init``.

    Returns
    -------
    None
    """
    X = np.asarray(X)
    l = np.asarray(l)
    edges = np.asarray(edges)

    # First three columns as separate 1-D coordinate arrays.
    x, y, z = X[:, :3].T
    label = np.argmax(l, axis=1) if l.ndim > 1 else l
    limits = [(X[:, i].min(), X[:, i].max()) for i in range(3)]

    # One (start, end) pair of 3-D points per edge: shape (n_edges, 2, 3).
    # An explicit empty array keeps that shape when there are no edges.
    if edges.size:
        segments = np.stack([X[edges[:, 0], :3], X[edges[:, 1], :3]], axis=1)
    else:
        segments = np.empty((0, 2, 3))

    # Panel 1: hits only, as small black dots.
    ax = _new_3d_axes(azim, elev)
    ax.scatter(x, y, z, c='k', s=2)
    _label_and_limit(ax, limits)
    plt.show()

    # Panel 2: hits colour-mapped by log(label + 1).
    ax = _new_3d_axes(azim, elev)
    ax.scatter(x, y, z, c=np.log(label + 1), s=2)
    _label_and_limit(ax, limits)
    plt.show()

    # Panel 3: edges only. The very low alpha presumably lets overlapping
    # edges accumulate into a visible density -- TODO confirm the intent.
    ax = _new_3d_axes(azim, elev)
    lc = Line3DCollection(segments, colors=plt.cm.Blues(0.9), alpha=0.007, lw=2)
    ax.add_collection3d(lc)
    _label_and_limit(ax, limits)
    plt.show()
# Demo: plot sample ``k`` of the dataset. ``X``, ``Y`` and ``X_clusters_graph``
# are not defined in this file -- presumably they come from an earlier notebook
# cell or session (TODO confirm). ``X[k]`` is passed as the point array,
# ``Y[k]`` as the labels and ``X_clusters_graph[k]['X_cluster_in_out']`` as the
# edge index array. The guard stops a plain import from raising NameError and
# from triggering the three ``plt.show()`` calls.
if __name__ == "__main__":
    k = 0
    plot_3d_with_edges(X[k], Y[k], X_clusters_graph[k]['X_cluster_in_out'])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment