Skip to content

Instantly share code, notes, and snippets.

@ugo-nama-kun
Last active January 13, 2023 04:53
Show Gist options
  • Save ugo-nama-kun/0613997a528e458231f825191123af52 to your computer and use it in GitHub Desktop.
Save ugo-nama-kun/0613997a528e458231f825191123af52 to your computer and use it in GitHub Desktop.
Beta distribution policy head (PyTorch)
class BetaHead(nn.Module):
    """Policy head that maps features to a factorised Beta distribution.

    Two parallel linear layers produce, for every action dimension, the raw
    values of the two Beta concentration parameters. ``forward`` passes them
    through softplus and adds 1, so each concentration is at least 1, which
    keeps every per-dimension density out of the U-shaped regime.

    Args:
        in_features: Size of the last dimension of the input features.
        action_size: Number of action dimensions (one Beta per dimension).
    """

    def __init__(self, in_features: int, action_size: int) -> None:
        super().__init__()
        # c0 is built and initialised before c1. Keep this order: it fixes
        # how the random-number stream is consumed, and therefore what a
        # given seed produces.
        self.fcc_c0 = self._make_concentration_layer(in_features, action_size)
        self.fcc_c1 = self._make_concentration_layer(in_features, action_size)

    @staticmethod
    def _make_concentration_layer(in_features: int, action_size: int) -> nn.Linear:
        """Return a linear layer with small orthogonal weights and zero bias."""
        layer = nn.Linear(in_features, action_size)
        # The tiny gain keeps the initial pre-activations close to zero, so
        # both concentrations start out nearly equal for every input.
        nn.init.orthogonal_(layer.weight, gain=0.01)
        nn.init.zeros_(layer.bias)
        return layer

    def forward(self, x: torch.Tensor) -> torch.distributions.Independent:
        """Build the action distribution for the features ``x``.

        Args:
            x: Tensor whose last dimension has size ``in_features``.

        Returns:
            An ``Independent`` wrapper around a ``Beta`` distribution. The
            last dimension (the action dimension) is treated as the event
            dimension, so ``log_prob`` sums over action dimensions.
        """
        softplus = torch.nn.functional.softplus
        # softplus is non-negative; the +1 offset bounds each concentration
        # below by 1.
        c0 = softplus(self.fcc_c0(x)) + 1.0
        c1 = softplus(self.fcc_c1(x)) + 1.0
        # Beta takes (concentration1, concentration0), hence c1 comes first.
        return torch.distributions.Independent(
            torch.distributions.Beta(c1, c0), 1
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment