Skip to content

Instantly share code, notes, and snippets.

@SchattenGenie
Created August 9, 2018 16:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save SchattenGenie/28204a1135c3b7bca06162b7b2adf073 to your computer and use it in GitHub Desktop.
Save SchattenGenie/28204a1135c3b7bca06162b7b2adf073 to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Line3DCollection
def _new_3d_axes(azim, elev, figsize):
    """Create a new figure containing one 3-D axes with the given camera angles."""
    fig = plt.figure(figsize=figsize)
    # ``fig.gca(projection='3d')`` was deprecated in Matplotlib 3.4 and removed
    # in 3.6; ``add_subplot`` works on both old and new releases.
    ax = fig.add_subplot(111, projection='3d')
    ax.view_init(azim=azim, elev=elev)
    return ax


def _finish_3d_axes(ax, limits):
    """Apply the shared axis labels and limits to *ax*, then display the figure."""
    # Data column 0 is drawn on the axis labelled "z" and column 2 on the axis
    # labelled "x". NOTE(review): presumably X stores coordinates in (z, y, x)
    # order -- confirm against the data producer.
    ax.set_xlabel("z")
    ax.set_ylabel("y")
    ax.set_zlabel("x")
    (x_min, x_max), (y_min, y_max), (z_min, z_max) = limits
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_zlim(z_min, z_max)
    plt.show()


def plot_3d_with_edges(X, l, edges, azim=-84, elev=10, *,
                       figsize=(12, 12), edge_alpha=0.007):
    """Show three 3-D views of a point cloud: hits, hits coloured by label, edges.

    Each view is drawn in its own figure and shown immediately with
    ``plt.show()``. All three views share the same camera angle and the same
    axis limits, which are the bounding box of the first three columns of ``X``.

    Parameters
    ----------
    X : array-like, shape (n_points, >=3)
        Point coordinates. Only columns 0, 1 and 2 are used; extra columns
        are ignored.
    l : array-like, shape (n_points,) or (n_points, n_classes)
        Per-point labels. A 2-D array is reduced to one index per row with
        ``argmax`` along axis 1 (presumably one-hot or per-class scores --
        confirm with the caller).
    edges : array-like of int, shape (n_edges, 2)
        Pairs of row indices into ``X``. Each pair is drawn as a line segment.
    azim, elev : float
        Camera azimuth and elevation passed to ``view_init``.
    figsize : tuple of float, keyword-only
        Size of each of the three figures.
    edge_alpha : float, keyword-only
        Opacity of every edge segment. The very low default presumably targets
        graphs with many overlapping edges.
    """
    X = np.asarray(X)
    l = np.asarray(l)
    # Normalise to an (n_edges, 2) index array so an empty edge list does not
    # break the column indexing below.
    edges = np.asarray(edges, dtype=np.intp).reshape(-1, 2)

    label = np.argmax(l, axis=1) if l.ndim > 1 else l

    x, y, z = X[:, 0], X[:, 1], X[:, 2]
    limits = [(X[:, i].min(), X[:, i].max()) for i in range(3)]

    # 1) Hits only, in black.
    ax = _new_3d_axes(azim, elev, figsize)
    ax.scatter(x, y, z, c='k', s=2)
    _finish_3d_axes(ax, limits)

    # 2) Hits coloured by log(label + 1).
    # NOTE(review): labels below 0 would give -inf/NaN colour values here.
    ax = _new_3d_axes(azim, elev, figsize)
    ax.scatter(x, y, z, c=np.log(label + 1), s=2)
    _finish_3d_axes(ax, limits)

    # 3) Edges as 3-D line segments between the endpoints' first three columns.
    ax = _new_3d_axes(azim, elev, figsize)
    start_points = X[edges[:, 0], :3]
    end_points = X[edges[:, 1], :3]
    segments = list(zip(start_points, end_points))
    if segments:
        lc = Line3DCollection(segments, colors=plt.cm.Blues(0.9),
                              alpha=edge_alpha, lw=2)
        ax.add_collection3d(lc)
    _finish_3d_axes(ax, limits)
# Visualise the sample at index k. Its edge list is the sample's
# 'X_cluster_in_out' entry. NOTE(review): X, Y and X_clusters_graph are not
# defined in this snippet -- presumably created by earlier notebook code.
k = 0
plot_3d_with_edges(
    X=X[k],
    l=Y[k],
    edges=X_clusters_graph[k]['X_cluster_in_out'],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment