@djbyrne
Last active March 26, 2020 08:12
Simple network to be used with a DQN
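import torch
import torch.nn as nn


class DQN(nn.Module):
    """
    Simple MLP network that maps an observation to one Q-value per action.

    Args:
        obs_size: observation/state size of the environment
        n_actions: number of discrete actions available in the environment
        hidden_size: size of the hidden layer
    """

    def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
        super(DQN, self).__init__()
        # obs -> hidden (ReLU) -> one output per discrete action
        self.net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cast to float32 so integer or float64 (e.g. NumPy-derived) observations work with nn.Linear
        return self.net(x.float())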
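The gist only defines the Q-network itself. The following is a minimal sketch of how such a network is commonly plugged into DQN. It assumes epsilon-greedy exploration, a separate target network, and a one-step TD loss with mean squared error; none of these appear in the gist. The helpers `select_action` and `dqn_loss` are hypothetical names, and the same architecture is re-declared so the block runs on its own.

```python
import random

import torch
import torch.nn as nn
import torch.nn.functional as F


class DQN(nn.Module):
    """Same architecture as the gist: obs -> hidden (ReLU) -> Q-values."""

    def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x.float())


def select_action(net: DQN, obs: torch.Tensor, n_actions: int, epsilon: float) -> int:
    """Hypothetical helper: epsilon-greedy choice over the network's Q-values."""
    if random.random() < epsilon:
        return random.randrange(n_actions)  # explore
    with torch.no_grad():
        q_values = net(obs.unsqueeze(0))  # add batch dim -> shape (1, n_actions)
    return int(q_values.argmax(dim=1).item())  # exploit


def dqn_loss(net, target_net, states, actions, rewards, next_states, dones, gamma=0.99):
    """Hypothetical helper: MSE between Q(s, a) and r + gamma * max_a' Q_target(s', a')."""
    # Q-values of the actions actually taken, shape (batch,)
    q_sa = net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():  # the target is treated as a constant
        next_q = target_net(next_states).max(dim=1).values
        # Terminal transitions do not bootstrap from the next state
        target = rewards + gamma * next_q * (1.0 - dones.float())
    return F.mse_loss(q_sa, target)


# Usage with random data (sizes are illustrative assumptions)
torch.manual_seed(0)
random.seed(0)
obs_size, n_actions, batch_size = 4, 2, 8
net = DQN(obs_size, n_actions)
target_net = DQN(obs_size, n_actions)
target_net.load_state_dict(net.state_dict())  # target starts as a copy of net

action = select_action(net, torch.randn(obs_size), n_actions, epsilon=0.1)

states = torch.randn(batch_size, obs_size)
actions = torch.randint(0, n_actions, (batch_size,))
rewards = torch.ones(batch_size)
next_states = torch.randn(batch_size, obs_size)
dones = torch.zeros(batch_size, dtype=torch.bool)

loss = dqn_loss(net, target_net, states, actions, rewards, next_states, dones)
loss.backward()  # gradients reach `net` only, not `target_net`
```

Computing the target under `torch.no_grad()` is what keeps the target network frozen between explicit syncs. In a full agent you would periodically copy `net`'s weights into `target_net` again with `load_state_dict`.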