Skip to content

Instantly share code, notes, and snippets.

@nagataka
Last active April 10, 2020 01:13
Show Gist options
  • Save nagataka/dfd966951374aaa194401a7f3bb215dd to your computer and use it in GitHub Desktop.
Save nagataka/dfd966951374aaa194401a7f3bb215dd to your computer and use it in GitHub Desktop.
A template to start a project using OpenAI gym with PyTorch
"""A template to implement RL agent with OpenAI Gym
Usage: python ./gym_template.py --env=CarRacing-v0 --algo=policy_gradient --epochs 1
Algorithm implementations must be importable modules in the ./algorithms/ directory;
otherwise, change the following line to match your project layout:
> algo = import_module('algorithms.'+args.algo)
"""
import argparse
import numpy as np
import gym
from importlib import import_module
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
def main(argv=None):
    """Parse command-line options, build the Gym environment, and train an agent.

    Args:
        argv: Optional list of command-line arguments to parse instead of the
            process arguments. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so running the script behaves as before.
    """
    parser = argparse.ArgumentParser(description="Main for training RL agent")
    parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                        help='discount factor (default: 0.99)')
    parser.add_argument('--epochs', type=int, default=1000, metavar='N',
                        help='number of epochs to train (default: 1,000)')
    # gym.make(None) can never succeed, so require the environment id and let
    # argparse print a usage error instead of failing deep inside gym.
    parser.add_argument('--env', type=str, required=True,
                        help="https://github.com/openai/gym/wiki/Table-of-environments")
    parser.add_argument('--algo', type=str, default='dqn', help="learning algorithm")
    args = parser.parse_args(argv)

    # NOTE(review): `device` is not used anywhere in this template yet;
    # presumably meant for algorithm code (models/tensors) -- wire it in or drop it.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    env_str = args.env
    env = gym.make(env_str)

    # Box (continuous) action spaces report their size through `shape`;
    # other spaces handled here are expected to expose a discrete count `.n`.
    if isinstance(env.action_space, gym.spaces.Box):
        num_actions = env.action_space.shape[0]
    else:
        num_actions = env.action_space.n
    observation_space = env.observation_space
    print("Created {} env which has {} actions in {} spaces".format(
        env_str, num_actions, observation_space))

    # Algorithm modules are looked up by name under the `algorithms` package.
    algo = import_module('algorithms.'+args.algo)
    agent = Agent(algo, env)
    agent.train(args.epochs)
class Agent():
    """Skeleton agent that runs episodes in a Gym-style environment.

    The algorithm module is stored but not yet consulted: actions are sampled
    from ``env.action_space``, so ``train`` is a placeholder interaction loop
    to be replaced with a real policy.
    """

    def __init__(self, algo, env):
        """Store the learning-algorithm module and the environment.

        Args:
            algo: Algorithm module (``main`` passes the result of
                ``import_module``); currently unused by ``train``.
            env: Environment exposing ``reset()``, ``step(action)`` returning
                a 4-tuple, ``render()``, ``close()`` and ``action_space.sample()``.
        """
        self.algo = algo
        self.env = env

    def train(self, num_epochs, render=True):
        """Run ``num_epochs`` episodes with randomly sampled actions.

        Args:
            num_epochs: Number of episodes to run.
            render: If True (the default), call ``env.render()`` after each
                reset and after every step; pass False for headless runs.

        Returns:
            A list with the total (undiscounted) reward of each episode.

        The environment is closed after the last episode, and also when an
        episode raises, so it is always released exactly once.
        """
        episode_rewards = []
        try:
            for _episode in range(num_epochs):
                state = self.env.reset()
                done = False
                total_reward = 0
                if render:
                    self.env.render()
                while not done:
                    # TODO: replace random sampling with self.algo-driven
                    # action selection based on `state`.
                    action = self.env.action_space.sample()
                    next_state, reward, done, _info = self.env.step(action)
                    if render:
                        self.env.render()
                    total_reward += reward
                    state = next_state
                print("Done with reward ", total_reward)
                episode_rewards.append(total_reward)
        finally:
            self.env.close()
        return episode_rewards
if __name__ == "__main__":
    # Start training only when this file is executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment