Skip to content

Instantly share code, notes, and snippets.

@RileyLazarou
Last active June 22, 2020 13:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save RileyLazarou/6c34d3299ffca8d706afeafbca85efa6 to your computer and use it in GitHub Desktop.
Save RileyLazarou/6c34d3299ffca8d706afeafbca85efa6 to your computer and use it in GitHub Desktop.
Vanilla GAN discriminator
class Discriminator(nn.Module):
    """A fully connected discriminator for discerning real from generated samples."""

    def __init__(self, input_dim, layers):
        """A discriminator for discerning real from generated samples.

        params:
            input_dim (int): width of the input; the last dimension of the
                tensor passed to ``forward`` must equal this value.
            layers (Iterable[int]): layer widths, including the output width.
                Must contain at least one entry.

        Hidden activations are LeakyReLU; the output activation is Sigmoid.

        raises:
            ValueError: if ``layers`` is empty. Without this check, no layers
                are built and ``forward`` returns its input unchanged.
        """
        super().__init__()
        self.input_dim = input_dim
        self._init_layers(layers)

    def _init_layers(self, layers):
        """Initialize the layers and store them as self.module_list.

        Each width adds an ``nn.Linear`` followed by an activation:
        ``nn.LeakyReLU`` after every hidden layer and ``nn.Sigmoid`` after the
        final (output) layer. The order of the modules is kept stable, so
        ``state_dict`` keys stay ``module_list.0``, ``module_list.2``, and so on.
        """
        # Materialize once so generators work and len() is well-defined.
        widths = list(layers)
        if not widths:
            raise ValueError(
                "layers must contain at least one width (the output width)"
            )

        self.module_list = nn.ModuleList()
        last_layer = self.input_dim
        last_index = len(widths) - 1
        for index, width in enumerate(widths):
            self.module_list.append(nn.Linear(last_layer, width))
            last_layer = width
            # Only the output layer is squashed into [0, 1].
            if index == last_index:
                self.module_list.append(nn.Sigmoid())
            else:
                self.module_list.append(nn.LeakyReLU())

    def forward(self, input_tensor):
        """Forward pass; map samples to confidence they are real [0, 1].

        params:
            input_tensor (Tensor): samples whose last dimension is input_dim.

        returns:
            Tensor whose last dimension is the final width in ``layers``, with
            values in [0, 1] because of the output Sigmoid.
        """
        intermediate = input_tensor
        for layer in self.module_list:
            intermediate = layer(intermediate)
        return intermediate
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment