Last active: June 22, 2020 13:54
-
-
Save RileyLazarou/0c89b9167e3d1db716359508115cbcc9 to your computer and use it in GitHub Desktop.
Vanilla GAN generator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Generator(nn.Module):
    """Vanilla GAN generator: a small MLP mapping latent vectors to samples.

    Architecture:
        Linear(latent_dim, 64) -> LeakyReLU -> Linear(64, 32) -> LeakyReLU
        -> Linear(32, 1) -> optional output activation.
    """

    def __init__(self, latent_dim: int, output_activation=None) -> None:
        """Build the generator.

        Args:
            latent_dim (int): Width of the latent ("noise") vector passed to
                ``forward``; this is the input size of the first linear layer.
            output_activation: Callable applied to the final layer's output
                when not ``None`` (e.g. ``torch.sigmoid`` or an ``nn.Module``
                instance). ``None`` leaves the output un-activated.
        """
        super().__init__()
        self.linear1 = nn.Linear(latent_dim, 64)
        # One LeakyReLU module is reused after both hidden layers; it holds
        # no parameters, so sharing it is equivalent to two separate copies.
        self.leaky_relu = nn.LeakyReLU()
        self.linear2 = nn.Linear(64, 32)
        self.linear3 = nn.Linear(32, 1)
        self.output_activation = output_activation

    def forward(self, input_tensor):
        """Forward pass; map latent vectors to samples.

        Args:
            input_tensor (torch.Tensor): Latent vectors whose last dimension
                has size ``latent_dim``.

        Returns:
            torch.Tensor: Tensor with the same leading dimensions as
            ``input_tensor`` and a last dimension of size 1, passed through
            ``output_activation`` if one was given.
        """
        intermediate = self.linear1(input_tensor)
        intermediate = self.leaky_relu(intermediate)
        intermediate = self.linear2(intermediate)
        intermediate = self.leaky_relu(intermediate)
        intermediate = self.linear3(intermediate)
        if self.output_activation is not None:
            intermediate = self.output_activation(intermediate)
        return intermediate
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.