Last active
June 22, 2020 13:54
-
-
Save RileyLazarou/6c34d3299ffca8d706afeafbca85efa6 to your computer and use it in GitHub Desktop.
Vanilla GAN discriminator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Discriminator(nn.Module):
    """A discriminator for discerning real from generated samples.

    The network is a stack of fully connected ``nn.Linear`` layers. Every
    hidden layer is followed by ``nn.LeakyReLU``. The final layer is followed
    by ``nn.Sigmoid``, so outputs lie in [0, 1].
    """

    def __init__(self, input_dim, layers):
        """Build the discriminator.

        params:
            input_dim (int): width of the input
            layers (Iterable[int]): layer widths, including the output width.
                Must contain at least one entry. Output activation is Sigmoid.

        raises:
            ValueError: if ``layers`` is empty. There would be no output
                layer, hence no Sigmoid, and the module would silently act
                as the identity.
        """
        super().__init__()
        self.input_dim = input_dim
        self._init_layers(layers)

    def _init_layers(self, layers):
        """Initialize the layers and store them as ``self.module_list``.

        Layers are appended as Linear/activation pairs. The ModuleList is
        kept (rather than nn.Sequential) so parameter names such as
        ``module_list.0.weight`` stay stable for saved state dicts.
        """
        # Materialize once: len() is needed below, and generators have no len().
        widths = list(layers)
        if not widths:
            raise ValueError(
                "layers must contain at least one width (the output width)"
            )
        self.module_list = nn.ModuleList()
        last_width = self.input_dim
        final_index = len(widths) - 1
        for index, width in enumerate(widths):
            self.module_list.append(nn.Linear(last_width, width))
            last_width = width
            # Hidden layers get LeakyReLU; only the output layer gets Sigmoid.
            if index == final_index:
                self.module_list.append(nn.Sigmoid())
            else:
                self.module_list.append(nn.LeakyReLU())

    def forward(self, input_tensor):
        """Forward pass; map samples to confidence they are real, in [0, 1].

        ``nn.Linear`` acts on the last dimension, so ``input_tensor``'s last
        dimension must equal ``input_dim``. The output's last dimension is
        the final entry of ``layers``.
        """
        intermediate = input_tensor
        for layer in self.module_list:
            intermediate = layer(intermediate)
        return intermediate
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment