@ShaneFlandermeyer
Created June 1, 2024 02:27
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
# Learning-curve CSVs: one file per JAX seed, and one PyTorch file holding all seeds
jax_paths = ['/home/shane/Downloads/data.csv']
torch_path = '/home/shane/Downloads/humanoid-stand.csv'
torch_data = pd.read_csv(torch_path)
jax_data = [pd.read_csv(path) for path in jax_paths]
plt.figure()
# TODO: Will need to average over seeds for multiple results files
plt.plot(jax_data[0]['Step'], jax_data[0]['Value'], label='tdmpc2-jax')
# Keep a single PyTorch seed (3) and only the first 1e6 environment steps
valid_torch = np.logical_and(torch_data['seed'] == 3, torch_data['step'] < 1e6)
plt.plot(torch_data['step'][valid_torch], torch_data['reward'][valid_torch],
         label='tdmpc2')
plt.xlabel('Environment interactions')
plt.ylabel('Reward')
plt.grid()
plt.legend()
plt.show()
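The TODO in the script notes that multiple results files will need to be averaged over seeds. A minimal sketch of one way to do that follows, assuming each run is available as a pair of step and value arrays. The helper name `mean_over_seeds` and its `num_points` parameter are hypothetical and not part of the original script. Each curve is linearly interpolated onto a shared step grid first, because separate seeds are not guaranteed to log at identical steps.

```python
import numpy as np


def mean_over_seeds(runs, num_points=200):
    """Average several (steps, values) learning curves on a shared step grid.

    Hypothetical helper: each run may log at different steps, so every curve
    is linearly interpolated onto the same grid before taking mean and std.
    """
    # Restrict the grid to the step range covered by every run, so that
    # np.interp never has to extrapolate past a run's first or last step.
    start = max(np.min(steps) for steps, _ in runs)
    stop = min(np.max(steps) for steps, _ in runs)
    grid = np.linspace(start, stop, num_points)

    curves = []
    for steps, values in runs:
        steps = np.asarray(steps, dtype=float)
        values = np.asarray(values, dtype=float)
        order = np.argsort(steps)  # np.interp expects increasing x-coordinates
        curves.append(np.interp(grid, steps[order], values[order]))

    stacked = np.stack(curves)  # shape: (num_runs, num_points)
    return grid, stacked.mean(axis=0), stacked.std(axis=0)


# Two synthetic seeds logged at different steps (illustrative data only).
run_a = (np.array([0, 5, 10]), np.array([0.0, 5.0, 10.0]))
run_b = (np.array([0, 10]), np.array([0.0, 20.0]))
grid, mean, std = mean_over_seeds([run_a, run_b], num_points=3)
print(grid, mean, std)
```

With the script's data, the JAX runs could be passed as `[(df['Step'].to_numpy(), df['Value'].to_numpy()) for df in jax_data]` and the PyTorch seeds as `[(g['step'].to_numpy(), g['reward'].to_numpy()) for _, g in torch_data.groupby('seed')]`. The mean can then be drawn with `plt.plot(grid, mean)` and the spread with `plt.fill_between(grid, mean - std, mean + std, alpha=0.3)`. The sketch reports the population standard deviation (NumPy's default `ddof=0`).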