Skip to content

Instantly share code, notes, and snippets.

@dongkwan-kim
Last active June 25, 2022 11:16
Show Gist options
Save dongkwan-kim/f9cba350c7df138a0f0a7848baff31d5 to your computer and use it in GitHub Desktop.
from typing import Tuple, List
try:
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator
from mpl_toolkits.mplot3d import Axes3D
except ImportError:
pass
import numpy as np
import pandas as pd
def create_fake_data(crossed_z=True):
    """Build a small synthetic 5x5 grid holding two surfaces for demo plots.

    :param crossed_z: if True, the second surface is ``X**2 + Y/10``, which
        crosses the first one; otherwise it is the first surface shifted
        down by 0.4, so the two never meet.
    :return: ``(X, Y, Z1, Z2)``, each a ``(5, 5)`` array whose rows follow
        the Y values and whose columns follow the X values.
    """
    x_mesh, y_mesh = np.meshgrid([0.1, 0.3, 0.5, 0.7, 0.9],
                                 [1.0, 2.0, 4.0, 8.0, 16.0])
    y_term = y_mesh / 10
    z_first = (x_mesh - 0.9) ** 2 + y_term
    z_second = (x_mesh ** 2 + y_term) if crossed_z else (z_first - 0.4)
    return x_mesh, y_mesh, z_first, z_second
def format_custom_data(X, Y, Z_list: List) -> Tuple:
    """Mesh 1-D X/Y coordinates and pair them with matching Z surfaces.

    :param X: 1-D sequence of x coordinates, e.g. ``[0.1, 0.3, 0.5, 0.7, 0.9]``.
    :param Y: 1-D sequence of y coordinates, e.g. ``[1.0, 2.0, 4.0, 8.0, 16.0]``.
    :param Z_list: surfaces to pair with the mesh. Each must be array-like of
        shape ``(len(Y), len(X))``, so that ``Z[i, j]`` is the height at
        ``(X[j], Y[i])``. For the example coordinates above, the meshed X is::

            [[0.1 0.3 0.5 0.7 0.9]
             [0.1 0.3 0.5 0.7 0.9]
             [0.1 0.3 0.5 0.7 0.9]
             [0.1 0.3 0.5 0.7 0.9]
             [0.1 0.3 0.5 0.7 0.9]]

        and the meshed Y is::

            [[ 1.  1.  1.  1.  1.]
             [ 2.  2.  2.  2.  2.]
             [ 4.  4.  4.  4.  4.]
             [ 8.  8.  8.  8.  8.]
             [16. 16. 16. 16. 16.]]

        An empty ``Z_list`` is allowed and yields just the two meshes.
    :return: ``(X_mesh, Y_mesh, Z_1, ..., Z_k)`` as numpy arrays.
    :raises ValueError: if any surface does not have the meshed shape.
    """
    X, Y = np.meshgrid(X, Y)
    Z_arrays = [np.asarray(z) for z in Z_list]
    # Check every surface (not only the first) with a real exception, so the
    # validation still runs under ``python -O``.
    for k, z in enumerate(Z_arrays):
        if z.shape != X.shape:
            raise ValueError(
                f"Z_list[{k}] has shape {z.shape}, expected {X.shape} "
                "(len(Y), len(X))"
            )
    return (X, Y, *Z_arrays)
def table_to_custom_xyz1z2_data(table: List[List[float]], *,
                                fill_value: float = -1.0) -> Tuple:
    """Convert long-format ``[X, Y, Z1, Z2]`` rows into meshed grid arrays.

    Each row gives the heights of two surfaces at one ``(X, Y)`` point. The
    distinct X values and distinct Y values, each sorted ascending, define
    the grid axes.

    :param table: rows of exactly four numbers ``[X, Y, Z1, Z2]``.
    :param fill_value: value written to grid cells that no row covers.
        Defaults to -1.0, the previous hard-coded sentinel. Pass e.g.
        ``np.nan`` when -1 could be a genuine Z value.
    :return: ``(X_mesh, Y_mesh, Z1, Z2)`` as produced by ``format_custom_data``.
        Each array has shape ``(n_unique_Y, n_unique_X)``.
    """
    df = pd.DataFrame(table, columns=["X", "Y", "Z1", "Z2"])
    X, Y = sorted(pd.unique(df.X)), sorted(pd.unique(df.Y))
    idx_x = {v: i for i, v in enumerate(X)}
    idx_y = {v: i for i, v in enumerate(Y)}
    shape = (len(Y), len(X))
    Z1 = np.full(shape, fill_value, dtype=float)
    Z2 = np.full(shape, fill_value, dtype=float)
    # Iterate the columns directly rather than via DataFrame.iterrows, which
    # builds a Series per row. Rows are still applied in table order, so a
    # repeated (X, Y) pair keeps the last row's values, as before.
    for x, y, z1, z2 in zip(df.X, df.Y, df.Z1, df.Z2):
        iy, ix = idx_y[y], idx_x[x]
        Z1[iy, ix] = z1
        Z2[iy, ix] = z2
    return format_custom_data(X, Y, Z_list=[Z1, Z2])
def plot_two_surfaces_3d(X, Y, Z1, Z2, path=None, *, show=True):
    """Draw two translucent 3-D surfaces on one set of axes.

    Each surface gets a fixed translucent RGBA face color plus grid lines
    colored by a colormap over that surface's own Z range.

    :param X: meshed x coordinates (e.g. from ``format_custom_data``).
    :param Y: meshed y coordinates, same shape as ``X``.
    :param Z1: first surface's heights, same shape as ``X``.
    :param Z2: second surface's heights, same shape as ``X``.
    :param path: if given, the figure is saved there before being shown.
    :param show: call ``plt.show()`` at the end; pass False for batch use.
    :return: ``(fig, ax)`` so callers can customize, save or close the figure.
    :raises ImportError: if matplotlib could not be imported at module load.
    """
    # The module-level matplotlib import is guarded by ``except ImportError:
    # pass``; report that clearly instead of failing with a NameError.
    if "plt" not in globals():
        raise ImportError("plot_two_surfaces_3d requires matplotlib")

    fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

    def set_wireframe_and_surface(x, y, z, wireframe_cm, surface_rgb):
        # Color each cell by its normalized height on the given colormap.
        colors = wireframe_cm(plt.Normalize(z.min(), z.max())(z))
        surf = ax.plot_surface(x, y, z,
                               facecolors=colors,
                               rstride=1, cstride=1, shade=True, linewidth=1)
        # Replace only the face colors with a flat translucent color.
        # NOTE(review): this relies on plot_surface also applying
        # `facecolors` to the edges, so the colormap survives as a
        # "wireframe"; confirm for the installed matplotlib version.
        surf.set_facecolor(surface_rgb)

    set_wireframe_and_surface(X, Y, Z1, cm.spring, (1, 0.9, 0.9, 0.3))
    set_wireframe_and_surface(X, Y, Z2, cm.winter, (0.9, 0.9, 1, 0.3))

    # Ten evenly spaced z ticks, labelled with two decimals (a format string
    # is wrapped in a StrMethodFormatter automatically).
    ax.zaxis.set_major_locator(LinearLocator(10))
    ax.zaxis.set_major_formatter('{x:.02f}')

    fig.tight_layout()
    # Save before showing: some backends release the figure once the
    # window from plt.show() is closed.
    if path is not None:
        fig.savefig(path)
    if show:
        plt.show()
    return fig, ax
if __name__ == '__main__':
    # Pick the demo input. "MESH" plots a ready-made Z grid together with
    # half of it. "TABLE" builds both grids from long-format
    # [X, Y, Z1, Z2] rows.
    FROM = "TABLE"

    if FROM == "MESH":
        # z[i][j] is the height at (x = X[j], y = Y[i]).
        z = np.asarray([
            [1.1, 1.3, 1.5, 1.7, 1.9],
            [2.1, 2.3, 2.5, 2.7, 2.9],
            [4.1, 4.3, 4.5, 4.7, 4.9],
            [8.1, 8.3, 8.5, 8.7, 8.9],
            [16.1, 16.3, 16.5, 16.7, 16.9],
        ])
        mesh_surfaces = format_custom_data(
            X=[0.1, 0.3, 0.5, 0.7, 0.9],
            Y=[1.0, 2.0, 4.0, 8.0, 16.0],
            Z_list=[z, z / 2],
        )
        plot_two_surfaces_3d(*mesh_surfaces, path="./3d_mesh.pdf")
    elif FROM == "TABLE":
        # Columns:   X    Y    Z1   Z2
        table_rows = [
            [0.1, 1.0, 0.0, 0.1],
            [0.1, 2.0, 0.3, 0.4],
            [0.5, 1.0, 0.6, 0.7],
            [0.5, 2.0, 0.9, 1.0],
            [0.7, 1.0, 0.6, 0.7],
            [0.7, 2.0, 0.9, 1.0],
        ]
        table_surfaces = table_to_custom_xyz1z2_data(table_rows)
        plot_two_surfaces_3d(*table_surfaces, path="./3d_table.pdf")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment