Skip to content

Instantly share code, notes, and snippets.

@JIElite
Created April 13, 2018 18:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save JIElite/174450bb9aff72a92ec698a85e66b6fc to your computer and use it in GitHub Desktop.
Save JIElite/174450bb9aff72a92ec698a85e66b6fc to your computer and use it in GitHub Desktop.
class FullyConv(nn.Module):
    """Two-layer convolutional network with a spatial policy head and a value head.

    Both convolutions use stride 1 with "same" padding, so the spatial size of
    the feature map equals the spatial size of the input. A 1x1 convolution
    turns the feature map into one logit per spatial location (softmaxed into
    ``action_prob``); a fully connected branch turns the flattened feature map
    into a scalar state value.

    Args:
        screen_channels: number of channels of the input tensor.
        screen_resolution: two-element sequence with the spatial size of the
            input; the product of its entries must equal ``H * W`` of the
            tensors passed to :meth:`forward`, because it sizes the fully
            connected layer.
    """

    def __init__(self, screen_channels, screen_resolution):
        super(FullyConv, self).__init__()
        self.conv1 = nn.Conv2d(screen_channels, 16, kernel_size=(5, 5), stride=1, padding=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=(3, 3), stride=1, padding=1)
        self.spatial_policy = nn.Conv2d(32, 1, kernel_size=(1, 1))
        self.non_spatial_branch = nn.Linear(screen_resolution[0] * screen_resolution[1] * 32, 256)
        self.value = nn.Linear(256, 1)

        # Xavier-initialise the convolution kernels. The in-place
        # ``xavier_uniform_`` replaces the deprecated ``xavier_uniform``.
        for conv in (self.conv1, self.conv2, self.spatial_policy):
            nn.init.xavier_uniform_(conv.weight)

        # Re-parameterise the linear layers with weight normalisation and
        # start their biases at zero.
        for linear in (self.non_spatial_branch, self.value):
            nn.utils.weight_norm(linear)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        """Compute the spatial action distribution and the state value.

        Args:
            x: tensor of shape ``(batch, screen_channels, H, W)``. A single
                unbatched ``(screen_channels, H, W)`` tensor is also accepted
                and treated as a batch of one.

        Returns:
            Tuple ``(action_prob, value)``: ``action_prob`` has shape
            ``(batch, H * W)`` and each row sums to one; ``value`` has shape
            ``(batch,)``.
        """
        if x.dim() == 3:
            # Promote an unbatched sample so per-sample flattening below works.
            x = x.unsqueeze(0)

        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        batch_size = x.shape[0]

        # Spatial policy branch: one logit per location, softmax over locations.
        spatial_logits = self.spatial_policy(x).reshape(batch_size, -1)
        action_prob = F.softmax(spatial_logits, dim=1)

        # Non-spatial branch: flatten each sample separately. Flattening with
        # ``view(-1)`` would merge the batch dimension into the features and
        # break the linear layer for any batch larger than one.
        flat_features = x.reshape(batch_size, -1)
        hidden = F.relu(self.non_spatial_branch(flat_features))
        # Drop the trailing singleton so a batch of one still yields shape (1,).
        value = self.value(hidden).squeeze(-1)
        return action_prob, value
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment