Created
June 4, 2024 02:52
-
-
Save edwhu/710d5202435831c5c761ea53d63a952e to your computer and use it in GitHub Desktop.
Compare TensorBoard episode-return curves from tdmpc2-jax with the original CSV results from tdmpc2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# TensorBoard CSV exports (columns "Step" / "Value") of episode_return for
# tdmpc2-jax, one file per seed.
runs = [
    "/Users/edward/Downloads/run-humanoid-stand_s0_2024-05-31_22-01-27_tensorboard-tag-episode_return.csv",
    "/Users/edward/Downloads/run-humanoid-stand_s1_2024-06-01_13-11-59_tensorboard-tag-episode_return.csv",
    "/Users/edward/Downloads/run-humanoid-stand_s2_2024-06-02_04-22-15_tensorboard-tag-episode_return.csv",
    "/Users/edward/Downloads/run-humanoid-stand_s3_2024-06-02_19-37-02_tensorboard-tag-episode_return.csv",
]
# Results CSV from the original tdmpc2 (columns "step" / "reward" / "seed",
# all seeds in one file). Only the first entry is used.
tdmpc2 = [
    "/Users/edward/Downloads/humanoid-stand.csv"
]


def mean_curve(curves):
    """Average several ``(steps, values)`` curves on a shared step grid.

    Each curve is sorted by step and linearly interpolated onto the sorted
    union of all logged steps that fall inside the step range covered by
    *every* curve, so each point of the result averages all curves. When the
    curves were logged at identical steps this is exactly a per-step mean.

    Averaging by row position instead would mix returns from different steps
    whenever runs log at slightly different steps or have different lengths.

    Args:
        curves: iterable of ``(steps, values)`` pairs of 1-D array-likes.

    Returns:
        ``(grid, mean)`` as 1-D float numpy arrays. Both are empty if there
        are no curves, any curve is empty, or the step ranges do not overlap.
    """
    prepared = []
    for steps, values in curves:
        steps = np.asarray(steps, dtype=float)
        values = np.asarray(values, dtype=float)
        # np.interp needs increasing sample points.
        order = np.argsort(steps, kind="stable")
        prepared.append((steps[order], values[order]))

    if not prepared or any(steps.size == 0 for steps, _ in prepared):
        return np.array([]), np.array([])

    # Restrict to the range every curve covers, so no point is an
    # extrapolation or an average over only a subset of the curves.
    lo = max(steps[0] for steps, _ in prepared)
    hi = min(steps[-1] for steps, _ in prepared)
    all_steps = np.unique(np.concatenate([steps for steps, _ in prepared]))
    grid = all_steps[(all_steps >= lo) & (all_steps <= hi)]

    stacked = np.stack([np.interp(grid, steps, values) for steps, values in prepared])
    return grid, stacked.mean(axis=0)


def main(run_paths=None, original_path=None, xmax=4_000_000):
    """Plot tdmpc2-jax seeds (blue) against original tdmpc2 seeds (orange).

    Individual seeds are drawn faintly; the across-seed mean (see
    ``mean_curve``) is drawn opaque and labelled.

    Args:
        run_paths: TensorBoard CSV exports for tdmpc2-jax; defaults to ``runs``.
        original_path: original tdmpc2 results CSV; defaults to ``tdmpc2[0]``.
        xmax: upper x-axis limit.
    """
    run_paths = runs if run_paths is None else run_paths
    original_path = tdmpc2[0] if original_path is None else original_path

    # tdmpc2-jax: one DataFrame per seed.
    jax_dfs = [pd.read_csv(path) for path in run_paths]
    for df in jax_dfs:
        plt.plot(df["Step"], df["Value"], color="blue", alpha=0.2)
    grid, mean = mean_curve((df["Step"], df["Value"]) for df in jax_dfs)
    plt.plot(grid, mean, color="blue", alpha=1, label="jax")

    # Original tdmpc2: all seeds in one file, split by the "seed" column.
    original_df = pd.read_csv(original_path)
    seed_groups = [group for _, group in original_df.groupby("seed")]
    for group in seed_groups:
        plt.plot(group["step"], group["reward"], color="orange", alpha=0.2)
    grid, mean = mean_curve((group["step"], group["reward"]) for group in seed_groups)
    plt.plot(grid, mean, color="orange", alpha=1, label="original")

    plt.xlim(0, xmax)
    plt.xlabel("step")
    plt.ylabel("episode return")
    plt.legend()
    plt.show()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment