# Save ShaneFlandermeyer/d37ba661d06e7e04cb25a195623ac007 to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Compare tdmpc2-jax learning curves against the original TD-MPC2 results on
# finger-turn-hard: every seed is drawn as a faint line and the per-step mean
# across seeds as a bold line.

base_dir = Path("/home/shane/onedrive/tdmpc2-comparison")

# tdmpc2-jax runs: one CSV per seed (TensorBoard exports, per the file names),
# read through their "Step" and "Value" columns.
runs = [
    # finger-turn-hard
    base_dir / "outputs_finger-turn-hard_s1_2024-06-03_14-46-41_tensorboard.csv",
    base_dir / "outputs_finger-turn-hard_s2_2024-06-03_18-34-16_tensorboard.csv",
    base_dir / "outputs_finger-turn-hard_s3_2024-06-03_22-18-35_tensorboard.csv",
]

# Original TD-MPC2 results: a single CSV with "seed", "step" and "reward"
# columns. Only the first entry of this list is plotted.
tdmpc2 = [
    base_dir / "finger-turn-hard.csv",
]


def mean_by_step(df, step_col, value_col):
    """Average ``value_col`` over all rows that share the same ``step_col``.

    Rows are aligned by their step value rather than by row position, so runs
    that logged a different number of points, or skipped a step, are still
    averaged at the correct x-coordinate. Only ``value_col`` is aggregated,
    so any other (possibly non-numeric) columns in ``df`` are ignored.

    Args:
        df: Table containing at least ``step_col`` and ``value_col``.
        step_col: Name of the column holding the environment step.
        value_col: Name of the column to average.

    Returns:
        A Series of means indexed by step, in ascending step order.
    """
    return df.groupby(step_col)[value_col].mean()


def main():
    """Load all runs, plot each seed and the per-step means, then show the figure."""
    # tdmpc2-jax: one faint line per seed.
    jax_dfs = [pd.read_csv(run) for run in runs]
    for df in jax_dfs:
        plt.plot(df["Step"], df["Value"], color="blue", alpha=0.2)
    # Average across seeds by step value (not by row position) so that values
    # logged at different steps are never mixed together.
    jax_mean = mean_by_step(pd.concat(jax_dfs, ignore_index=True), "Step", "Value")
    plt.plot(jax_mean.index, jax_mean, color="blue",
             alpha=1, label="tdmpc2-jax", linewidth=2)

    # Original TD-MPC2: one faint line per seed, then the mean across seeds.
    tdmpc2_df = pd.read_csv(tdmpc2[0])
    for _, group in tdmpc2_df.groupby("seed"):
        plt.plot(group["step"], group["reward"], color="orange", alpha=0.2)
    tdmpc2_mean = mean_by_step(tdmpc2_df, "step", "reward")
    plt.plot(tdmpc2_mean.index, tdmpc2_mean, color="orange",
             alpha=1, label="Original", linewidth=2)

    plt.legend()
    plt.xlabel("Environment Interactions", fontsize=12, fontweight="bold")
    plt.ylabel("Episodic Return", fontsize=12, fontweight="bold")
    # One fixed x-range shared by both methods.
    plt.xlim(0, 2e6)
    plt.show()


if __name__ == "__main__":
    main()
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.