Skip to content

Instantly share code, notes, and snippets.

@cboettig
Last active June 4, 2021 05:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cboettig/27573c009238eb1b0ed1af3ab9e845cf to your computer and use it in GitHub Desktop.
Save cboettig/27573c009238eb1b0ed1af3ab9e845cf to your computer and use it in GitHub Desktop.
ppo puzzle
import numpy as np
import stable_baselines3 as sb3
from stable_baselines3.common.env_util import make_vec_env
from torch import nn as nn
import os
import gym
import gym_fishing
# One seed is reused for NumPy, the training environments and the agent.
seed = 24
np.random.seed(seed)

# Training environment: n_envs=4 instances of "fishing-v1" built by
# make_vec_env, all seeded from the same value.
env = make_vec_env("fishing-v1", n_envs=4, seed=seed)

# PPO settings, forwarded unchanged to the constructor as keyword arguments.
hyper = dict(
    batch_size=64,
    n_steps=32,
    gamma=0.9999,
    learning_rate=0.00037996229760279524,
    ent_coef=2.392750198613048e-08,
    clip_range=0.3,
    n_epochs=10,
    gae_lambda=0.92,
    max_grad_norm=0.7,
    vf_coef=0.4863551514846155,
    policy_kwargs=dict(
        net_arch=[64, 64],
        activation_fn=nn.ReLU,
    ),
)

# Build the agent on the training environment.
agent = sb3.PPO("MlpPolicy", env, seed=seed, **hyper)

# Train and save only when no saved model is already on disk, so a second
# run goes straight to evaluation.
# NOTE(review): the listing this was recovered from had lost its indentation;
# both learn() and save() are assumed to belong to the `if` body -- confirm.
model_is_cached = os.path.exists("PPO_tuned.zip")
if not model_is_cached:
    agent.learn(total_timesteps=300_000)
    agent.save("PPO_tuned")

# Evaluation: `env` and `agent` are deliberately rebound to a fresh
# environment from gym.make and to the model reloaded from disk.
env = gym.make("fishing-v1")
agent = sb3.PPO.load("PPO_tuned")
agent_sims = env.simulate(agent, reps=100)
agent_policy = env.policyfn(agent, reps=5)

# Write the simulation and policy figures to PNG files.
env.plot(agent_sims, "fishing_PPO_tuned.png")
env.plot_policy(agent_policy, "fishing_PPO_tuned_policy.png")
absl-py==0.12.0
alabaster==0.7.12
alembic==1.6.5
apipkg==1.5
appdirs==1.4.4
atari-py==0.2.6
attrs==21.2.0
Babel==2.9.1
black==21.5b2
box2d-py==2.3.8
cachetools==4.2.2
certifi==2021.5.30
chardet==4.0.0
click==8.0.1
cliff==3.8.0
cloudpickle==1.6.0
cmaes==0.8.2
cmd2==1.5.0
colorama==0.4.4
colorlog==5.0.1
coverage==5.5
cycler==0.10.0
DataProperty==0.50.1
decorator==4.4.2
docutils==0.16
execnet==1.8.1
flake8==3.9.2
flake8-bugbear==21.4.3
google-auth==1.30.1
google-auth-oauthlib==0.4.4
greenlet==1.1.0
grpcio==1.38.0
gym==0.18.3
gym-conservation==0.0.5
gym-fishing==0.1.0
gym-minigrid==1.0.2
idna==2.10
imagesize==1.2.0
importlab==0.6.1
iniconfig==1.1.1
isort==5.8.0
Jinja2==3.0.1
joblib==1.0.1
kiwisolver==1.3.1
livereload==2.6.3
Mako==1.1.4
Markdown==3.3.4
MarkupSafe==2.0.1
matplotlib==3.4.2
mbstrdecoder==1.0.1
mccabe==0.6.1
msgfy==0.1.0
mypy-extensions==0.4.3
networkx==2.5.1
ninja==1.10.0.post2
numpy==1.20.3
oauthlib==3.1.1
opencv-python==4.5.2.52
optuna==2.7.0
packaging==20.9
pandas==1.2.4
pathspec==0.8.1
pathvalidate==2.4.1
pbr==5.6.0
Pillow==8.2.0
pkg-resources==0.0.0
pluggy==0.13.1
prettytable==2.1.0
protobuf==3.17.1
psutil==5.8.0
py==1.10.0
pyaml==20.4.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pybullet==3.1.7
pycodestyle==2.7.0
pyenchant==3.2.0
pyflakes==2.3.1
pyglet==1.5.15
Pygments==2.9.0
pyparsing==2.4.7
pyperclip==1.8.2
pytablewriter==0.60.0
pytest==6.2.4
pytest-cov==2.12.1
pytest-env==0.6.2
pytest-forked==1.3.0
pytest-xdist==2.2.1
python-dateutil==2.8.1
python-editor==1.0.4
pytype==2021.5.25
pytz==2021.1
PyYAML==5.4.1
regex==2021.4.4
requests==2.25.1
requests-oauthlib==1.3.0
rsa==4.7.2
sb3-contrib==1.1.0a7
scikit-learn==0.24.2
scikit-optimize==0.8.1
scipy==1.6.3
seaborn==0.11.1
six==1.16.0
snowballstemmer==2.1.0
Sphinx==4.0.2
sphinx-autobuild==2021.3.14
sphinx-autodoc-typehints==1.12.0
sphinx-rtd-theme==0.5.2
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
sphinxcontrib-spelling==7.2.1
SQLAlchemy==1.4.17
stable-baselines3==1.1.0a7
stevedore==3.3.0
tabledata==1.1.4
tcolorpy==0.0.9
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
threadpoolctl==2.1.0
toml==0.10.2
torch==1.8.1
tornado==6.1
tqdm==4.61.0
typed-ast==1.4.3
typepy==1.1.5
typing-extensions==3.10.0.0
urllib3==1.26.5
wcwidth==0.2.5
Werkzeug==2.0.1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment