@araffin
Created September 18, 2018 09:43
import pytest
import numpy as np

from stable_baselines import A2C, ACER, ACKTR, DQN, PPO1, PPO2, TRPO
from stable_baselines.common import set_global_seeds

MODEL_LIST_DISCRETE = [
    A2C,
    ACER,
    ACKTR,
    DQN,
    PPO1,
    PPO2,
    TRPO,
]


@pytest.mark.parametrize("model_class", MODEL_LIST_DISCRETE)
def test_perf_cartpole(model_class):
    """
    Test if the algorithm (with a given policy)
    can learn something on the simple CartPole environment.

    :param model_class: (BaseRLModel) the model class to benchmark
    """
    # TODO: multiprocess if possible (see the SubprocVecEnv sketch below)
    model = model_class(policy="MlpPolicy", env='CartPole-v1',
                        tensorboard_log="/tmp/log/perf/cartpole")
    model.learn(total_timesteps=int(1e5), seed=0)

    # Evaluate the trained model for a fixed number of steps
    env = model.get_env()
    n_steps = 2000
    set_global_seeds(0)
    obs = env.reset()
    episode_rewards = []
    reward_sum = 0
    for _ in range(n_steps):
        action, _ = model.predict(obs)
        # The wrapped VecEnv resets automatically when an episode ends
        obs, reward, done, _ = env.step(action)
        reward_sum += reward
        if done:
            episode_rewards.append(reward_sum)
            reward_sum = 0
    # CartPole-v1 caps episodes at 500 steps; require an average return of at least 100
    assert np.mean(episode_rewards) >= 100
    # Free memory
    del model, env
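
The TODO above mentions multiprocessing. Below is a minimal sketch of how training could be parallelized with stable-baselines' SubprocVecEnv, for the algorithms that support multiple environments (e.g. A2C, PPO2, ACKTR); the make_env helper and the n_envs value are illustrative assumptions, not part of the original benchmark:

import gym
from stable_baselines import PPO2
from stable_baselines.common.vec_env import SubprocVecEnv

def make_env(rank, seed=0):
    """Illustrative helper: create a seeded CartPole env for worker `rank`."""
    def _init():
        env = gym.make('CartPole-v1')
        env.seed(seed + rank)  # give each worker process a distinct seed
        return env
    return _init

if __name__ == '__main__':
    n_envs = 4  # assumed value, tune to the number of available cores
    # Each environment runs in its own process
    env = SubprocVecEnv([make_env(i) for i in range(n_envs)])
    model = PPO2("MlpPolicy", env, tensorboard_log="/tmp/log/perf/cartpole")
    model.learn(total_timesteps=int(1e5), seed=0)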
Run the benchmark and inspect the learning curves in TensorBoard:

pytest -v cartpole_bench.py
tensorboard --logdir /tmp/log/perf/
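
To benchmark a single algorithm, pytest's -k filter can select one parametrized case (assuming your pytest version derives the test id from the class name, e.g. test_perf_cartpole[DQN]):

pytest -v cartpole_bench.py -k DQN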