Last active
June 4, 2021 05:07
-
-
Save cboettig/27573c009238eb1b0ed1af3ab9e845cf to your computer and use it in GitHub Desktop.
ppo puzzle
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Train (or reuse) a tuned PPO agent on "fishing-v1", then simulate and plot it.

Steps, in order:
1. Build a 4-environment vectorized "fishing-v1" environment for training.
2. Construct a Stable-Baselines3 PPO agent with the tuned hyperparameters.
3. Train and save the agent, unless "PPO_tuned.zip" already exists.
4. Reload the saved model, run it on a fresh single environment, and pass the
   results to the environment's plotting helpers with two PNG file names.
"""
import os

import gym
import gym_fishing  # noqa: F401 -- never referenced by name; presumably imported to register "fishing-v1" (confirm)
import numpy as np
import stable_baselines3 as sb3
from stable_baselines3.common.env_util import make_vec_env
from torch import nn as nn

seed = 24
np.random.seed(seed)

# Create a (vectorized) fishing environment for training
env = make_vec_env("fishing-v1", n_envs=4, seed=seed)

# Tuned hyperparameters, forwarded verbatim as keyword arguments to sb3.PPO
hyper = {
    "batch_size": 64,
    "n_steps": 32,
    "gamma": 0.9999,
    "learning_rate": 0.00037996229760279524,
    "ent_coef": 2.392750198613048e-08,
    "clip_range": 0.3,
    "n_epochs": 10,
    "gae_lambda": 0.92,
    "max_grad_norm": 0.7,
    "vf_coef": 0.4863551514846155,
    "policy_kwargs": {
        "net_arch": [64, 64],
        "activation_fn": nn.ReLU,
    },
}

# Create an agent
agent = sb3.PPO("MlpPolicy", env, seed=seed, **hyper)

# Train the agent -- skipped when a saved model from an earlier run is on disk.
# Both learn() and save() belong inside the guard: saving outside it would
# overwrite an existing trained model with this freshly constructed agent.
if not os.path.exists("PPO_tuned.zip"):
    agent.learn(total_timesteps=300000)
    agent.save("PPO_tuned")

# Evaluate the trained agent (always the copy reloaded from disk)
# NOTE(review): this evaluation environment is created without an explicit
# seed, unlike the training environment above -- confirm that is intended.
env = gym.make("fishing-v1")
agent = sb3.PPO.load("PPO_tuned")
agent_sims = env.simulate(agent, reps=100)
agent_policy = env.policyfn(agent, reps=5)

# Plot results
env.plot(agent_sims, "fishing_PPO_tuned.png")
env.plot_policy(agent_policy, "fishing_PPO_tuned_policy.png")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
absl-py==0.12.0 | |
alabaster==0.7.12 | |
alembic==1.6.5 | |
apipkg==1.5 | |
appdirs==1.4.4 | |
atari-py==0.2.6 | |
attrs==21.2.0 | |
Babel==2.9.1 | |
black==21.5b2 | |
box2d-py==2.3.8 | |
cachetools==4.2.2 | |
certifi==2021.5.30 | |
chardet==4.0.0 | |
click==8.0.1 | |
cliff==3.8.0 | |
cloudpickle==1.6.0 | |
cmaes==0.8.2 | |
cmd2==1.5.0 | |
colorama==0.4.4 | |
colorlog==5.0.1 | |
coverage==5.5 | |
cycler==0.10.0 | |
DataProperty==0.50.1 | |
decorator==4.4.2 | |
docutils==0.16 | |
execnet==1.8.1 | |
flake8==3.9.2 | |
flake8-bugbear==21.4.3 | |
google-auth==1.30.1 | |
google-auth-oauthlib==0.4.4 | |
greenlet==1.1.0 | |
grpcio==1.38.0 | |
gym==0.18.3 | |
gym-conservation==0.0.5 | |
gym-fishing==0.1.0 | |
gym-minigrid==1.0.2 | |
idna==2.10 | |
imagesize==1.2.0 | |
importlab==0.6.1 | |
iniconfig==1.1.1 | |
isort==5.8.0 | |
Jinja2==3.0.1 | |
joblib==1.0.1 | |
kiwisolver==1.3.1 | |
livereload==2.6.3 | |
Mako==1.1.4 | |
Markdown==3.3.4 | |
MarkupSafe==2.0.1 | |
matplotlib==3.4.2 | |
mbstrdecoder==1.0.1 | |
mccabe==0.6.1 | |
msgfy==0.1.0 | |
mypy-extensions==0.4.3 | |
networkx==2.5.1 | |
ninja==1.10.0.post2 | |
numpy==1.20.3 | |
oauthlib==3.1.1 | |
opencv-python==4.5.2.52 | |
optuna==2.7.0 | |
packaging==20.9 | |
pandas==1.2.4 | |
pathspec==0.8.1 | |
pathvalidate==2.4.1 | |
pbr==5.6.0 | |
Pillow==8.2.0 | |
# pkg-resources==0.0.0  # Debian/Ubuntu `pip freeze` artifact, not an installable package; kept for reference only
pluggy==0.13.1 | |
prettytable==2.1.0 | |
protobuf==3.17.1 | |
psutil==5.8.0 | |
py==1.10.0 | |
pyaml==20.4.0 | |
pyasn1==0.4.8 | |
pyasn1-modules==0.2.8 | |
pybullet==3.1.7 | |
pycodestyle==2.7.0 | |
pyenchant==3.2.0 | |
pyflakes==2.3.1 | |
pyglet==1.5.15 | |
Pygments==2.9.0 | |
pyparsing==2.4.7 | |
pyperclip==1.8.2 | |
pytablewriter==0.60.0 | |
pytest==6.2.4 | |
pytest-cov==2.12.1 | |
pytest-env==0.6.2 | |
pytest-forked==1.3.0 | |
pytest-xdist==2.2.1 | |
python-dateutil==2.8.1 | |
python-editor==1.0.4 | |
pytype==2021.5.25 | |
pytz==2021.1 | |
PyYAML==5.4.1 | |
regex==2021.4.4 | |
requests==2.25.1 | |
requests-oauthlib==1.3.0 | |
rsa==4.7.2 | |
sb3-contrib==1.1.0a7 | |
scikit-learn==0.24.2 | |
scikit-optimize==0.8.1 | |
scipy==1.6.3 | |
seaborn==0.11.1 | |
six==1.16.0 | |
snowballstemmer==2.1.0 | |
Sphinx==4.0.2 | |
sphinx-autobuild==2021.3.14 | |
sphinx-autodoc-typehints==1.12.0 | |
sphinx-rtd-theme==0.5.2 | |
sphinxcontrib-applehelp==1.0.2 | |
sphinxcontrib-devhelp==1.0.2 | |
sphinxcontrib-htmlhelp==2.0.0 | |
sphinxcontrib-jsmath==1.0.1 | |
sphinxcontrib-qthelp==1.0.3 | |
sphinxcontrib-serializinghtml==1.1.5 | |
sphinxcontrib-spelling==7.2.1 | |
SQLAlchemy==1.4.17 | |
stable-baselines3==1.1.0a7 | |
stevedore==3.3.0 | |
tabledata==1.1.4 | |
tcolorpy==0.0.9 | |
tensorboard==2.5.0 | |
tensorboard-data-server==0.6.1 | |
tensorboard-plugin-wit==1.8.0 | |
threadpoolctl==2.1.0 | |
toml==0.10.2 | |
torch==1.8.1 | |
tornado==6.1 | |
tqdm==4.61.0 | |
typed-ast==1.4.3 | |
typepy==1.1.5 | |
typing-extensions==3.10.0.0 | |
urllib3==1.26.5 | |
wcwidth==0.2.5 | |
Werkzeug==2.0.1 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment