Skip to content

Instantly share code, notes, and snippets.

@Kaixhin
Last active March 26, 2020 16:26
Show Gist options
  • Save Kaixhin/c88235ae4e8a4608aa7e297699b05399 to your computer and use it in GitHub Desktop.
Tensor plotting functions
from matplotlib import pyplot as plt
from mpl_toolkits import mplot3d
def scatter(X, Y, c=None, ax=None):
    """Draw a 2-D scatter plot of ``Y`` against ``X`` and return the Axes.

    Args:
        X, Y: point coordinates. Torch tensors are detached and copied to
            host memory before plotting, so tensors that require grad or
            live on a GPU are accepted. Other array-likes (NumPy arrays,
            lists) are passed to matplotlib unchanged.
        c: optional per-point colour values, converted the same way.
        ax: Axes to draw on; when ``None`` one is created with
            ``plt.axes()``.

    Returns:
        The Axes that was drawn on.
    """
    def _as_numpy(t):
        # Tensor.numpy() refuses tensors that require grad or are not on the
        # CPU, so detach and move to host first. Anything without .detach()
        # (NumPy arrays, lists, None) is returned untouched.
        return t.detach().cpu().numpy() if hasattr(t, 'detach') else t

    if ax is None:
        ax = plt.axes()
    ax.scatter(_as_numpy(X), _as_numpy(Y), c=_as_numpy(c))
    return ax
def contour(X, Y, Z, levels=None, ax=None):
    """Draw contour lines of ``Z`` over the coordinates ``X``, ``Y``; return the Axes.

    Args:
        X, Y: coordinates of the ``Z`` values, in whatever layout
            matplotlib's ``Axes.contour`` accepts (typically 2-D grids
            matching ``Z``, or 1-D axis vectors).
        Z: values to contour.
        levels: forwarded to ``Axes.contour`` as ``levels``. It may also be
            a tensor.
        ax: Axes to draw on; when ``None`` one is created with
            ``plt.axes()``.

    Torch tensors are detached and copied to host memory first, so tensors
    that require grad or live on a GPU are accepted. Other array-likes are
    passed through unchanged.

    Returns:
        The Axes that was drawn on.
    """
    def _as_numpy(t):
        # Tensor.numpy() refuses tensors that require grad or are not on the
        # CPU, so detach and move to host first. Anything without .detach()
        # (NumPy arrays, lists, ints, None) is returned untouched.
        return t.detach().cpu().numpy() if hasattr(t, 'detach') else t

    if ax is None:
        ax = plt.axes()
    ax.contour(_as_numpy(X), _as_numpy(Y), _as_numpy(Z), levels=_as_numpy(levels))
    return ax
def scatter3D(X, Y, Z, ax=None, c=None):
    """Draw a 3-D scatter plot of the points (X, Y, Z) and return the Axes.

    Args:
        X, Y, Z: point coordinates. Torch tensors are detached and copied to
            host memory before plotting, so tensors that require grad or
            live on a GPU are accepted. Other array-likes pass through.
        ax: 3-D Axes to draw on; when ``None`` one is created with
            ``plt.axes(projection='3d')``.
        c: optional per-point colour values. When ``None`` (the default)
            points are coloured by their ``Z`` value, as before.

    Returns:
        The Axes that was drawn on.
    """
    def _as_numpy(t):
        # Tensor.numpy() refuses tensors that require grad or are not on the
        # CPU, so detach and move to host first. Anything without .detach()
        # (NumPy arrays, lists, None) is returned untouched.
        return t.detach().cpu().numpy() if hasattr(t, 'detach') else t

    if ax is None:
        ax = plt.axes(projection='3d')
    z = _as_numpy(Z)  # convert once; reused as the default colour array
    colours = z if c is None else _as_numpy(c)
    ax.scatter3D(_as_numpy(X), _as_numpy(Y), z, c=colours)
    return ax
def plot3D(X, Y, Z, ax=None, fmt='gray'):
    """Draw a 3-D line through the points (X, Y, Z) and return the Axes.

    Args:
        X, Y, Z: point coordinates, joined in order. Torch tensors are
            detached and copied to host memory before plotting, so tensors
            that require grad or live on a GPU are accepted. Other
            array-likes pass through unchanged.
        ax: 3-D Axes to draw on; when ``None`` one is created with
            ``plt.axes(projection='3d')``.
        fmt: format/colour string passed positionally to ``plot3D``.
            Defaults to ``'gray'``, the previously hard-coded value.

    Returns:
        The Axes that was drawn on.
    """
    def _as_numpy(t):
        # Tensor.numpy() refuses tensors that require grad or are not on the
        # CPU, so detach and move to host first. Anything without .detach()
        # (NumPy arrays, lists) is returned untouched.
        return t.detach().cpu().numpy() if hasattr(t, 'detach') else t

    if ax is None:
        ax = plt.axes(projection='3d')
    ax.plot3D(_as_numpy(X), _as_numpy(Y), _as_numpy(Z), fmt)
    return ax
#def surface(X, Y, Z, rstride=1, cstride=1, ax=None):
# ax = plt.axes(projection='3d') if ax is None else ax
# ax.plot_surface(X.numpy(), Y.numpy(), Z.numpy(), rstride=rstride, cstride=cstride)
# return ax
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment