Skip to content

Instantly share code, notes, and snippets.

@ShaneFlandermeyer
Created June 7, 2024 15:11
Show Gist options
  • Save ShaneFlandermeyer/d37ba661d06e7e04cb25a195623ac007 to your computer and use it in GitHub Desktop.
Save ShaneFlandermeyer/d37ba661d06e7e04cb25a195623ac007 to your computer and use it in GitHub Desktop.
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
base_dir = "/home/shane/onedrive/tdmpc2-comparison"

# tdmpc2-jax TensorBoard scalar exports for finger-turn-hard, one CSV per
# seed (s1, s2, s3).
runs = [
    f"{base_dir}/outputs_finger-turn-hard_s1_2024-06-03_14-46-41_tensorboard.csv",
    f"{base_dir}/outputs_finger-turn-hard_s2_2024-06-03_18-34-16_tensorboard.csv",
    f"{base_dir}/outputs_finger-turn-hard_s3_2024-06-03_22-18-35_tensorboard.csv",
]

# Results from the original TD-MPC2 implementation for the same task; a
# single CSV whose rows are split per seed via its 'seed' column.
tdmpc2 = [f"{base_dir}/finger-turn-hard.csv"]
# Read each tdmpc2-jax TensorBoard export (columns used: 'Step', 'Value')
# into its own DataFrame.
dfs = [pd.read_csv(run) for run in runs]

# Plot each seed as a faint blue line.
for df in dfs:
    plt.plot(df['Step'], df['Value'], color='blue', alpha=0.2)

# Average 'Value' across seeds at each environment step. Grouping on the
# 'Step' column (instead of the positional row index) keeps the average
# aligned even if one run is shorter or is missing a logged point.
# as_index=False keeps 'Step' as a regular column, so mean_df['Step'] and
# mean_df['Value'] remain available to later code.
# NOTE(review): assumes all seeds log at identical step values -- confirm;
# otherwise each step's "mean" covers only the seeds that logged it.
mean_df = (
    pd.concat(dfs, ignore_index=True)
    .groupby('Step', as_index=False)['Value']
    .mean()
)

# Plot the across-seed mean as a solid blue line.
plt.plot(mean_df['Step'], mean_df['Value'], color='blue',
         alpha=1, label='tdmpc2-jax', linewidth=2)
# Original TD-MPC2 results: one CSV containing all seeds; the columns used
# here are 'step', 'reward' and 'seed'.
tdmpc2_df = pd.read_csv(tdmpc2[0])

# Plot each seed's curve as a faint orange line.
for _, group in tdmpc2_df.groupby('seed'):
    plt.plot(group['step'], group['reward'], color='orange', alpha=0.2)

# Mean reward across seeds at each step. Selecting 'reward' before .mean()
# avoids averaging the 'seed' column and avoids a TypeError under pandas >= 2.0
# should the CSV contain any non-numeric columns. The double brackets keep a
# DataFrame (indexed by 'step') so mean_tdmpc2_df['reward'] works as before.
mean_tdmpc2_df = tdmpc2_df.groupby('step')[['reward']].mean()

# Plot the across-seed mean as a solid orange line.
plt.plot(mean_tdmpc2_df.index,
         mean_tdmpc2_df['reward'], color='orange', alpha=1, label='Original', linewidth=2)
# Fix the x-axis to the first 2M environment interactions. This is the only
# xlim call, so it alone determines the visible range.
plt.xlim(0, 2e6)
plt.legend()
plt.xlabel('Environment Interactions', fontsize=12, fontweight='bold')
plt.ylabel('Episodic Return', fontsize=12, fontweight='bold')
# Show the plot
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment