# The gist shows only the __init__ method, so the imports and the enclosing
# LightningModule are sketched in for context. The class name is a placeholder;
# MLP and PolicyAgent are assumed to come from pl_bolts, as in its RL models.
import gym
import numpy as np
import pytorch_lightning as pl

from pl_bolts.models.rl.common.agents import PolicyAgent  # assumed source
from pl_bolts.models.rl.common.networks import MLP        # assumed source


class VanillaPolicyGradient(pl.LightningModule):  # placeholder class name

    def __init__(self, env: str, gamma: float = 0.99, lr: float = 1e-2, batch_size: int = 8,
                 n_steps: int = 10, avg_reward_len: int = 100, entropy_beta: float = 0.01,
                 epoch_len: int = 1000, *args, **kwargs) -> None:
        super().__init__()

        # Model components
        self.env = gym.make(env)
        self.net = MLP(self.env.observation_space.shape, self.env.action_space.n)
        self.agent = PolicyAgent(self.net)

        # Hyperparameters
        self.lr = lr
        self.batch_size = batch_size
        self.batches_per_epoch = self.batch_size * epoch_len
        self.entropy_beta = entropy_beta
        self.gamma = gamma
        self.n_steps = n_steps
        self.save_hyperparameters()

        # Tracking metrics
        self.total_rewards = []          # return of every finished episode
        self.episode_rewards = []        # rewards collected in the current episode
        self.done_episodes = 0
        self.avg_rewards = 0
        self.avg_reward_len = avg_reward_len  # window size for the running average reward
        self.eps = np.finfo(np.float32).eps.item()  # small constant to avoid division by zero

        # Experience buffers and the current environment state
        self.batch_states = []
        self.batch_actions = []
        self.state = self.env.reset()
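For context, a minimal usage sketch, not part of the original gist: it assumes the placeholder class name above and that the full module also defines the standard Lightning hooks (training_step, train_dataloader, configure_optimizers), as a complete policy-gradient LightningModule would.

# Hypothetical usage, assuming the rest of the LightningModule is defined.
import pytorch_lightning as pl

model = VanillaPolicyGradient(env="CartPole-v0", lr=1e-2, batch_size=8)
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model)  # runs the training loop defined by the module's hooks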