Skip to content

Instantly share code, notes, and snippets.

@RileyLazarou
Last active June 22, 2020 13:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save RileyLazarou/0c89b9167e3d1db716359508115cbcc9 to your computer and use it in GitHub Desktop.
Save RileyLazarou/0c89b9167e3d1db716359508115cbcc9 to your computer and use it in GitHub Desktop.
vanilla gan generator
class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to samples.

    The network is a stack of ``nn.Linear`` layers with a shared
    ``nn.LeakyReLU`` applied between consecutive layers (not after the last
    one), optionally followed by ``output_activation``.
    """

    def __init__(self, latent_dim, output_activation=None, layers=None):
        """Build the linear stack.

        Args:
            latent_dim (int): width of the latent ("noise") input vector.
            output_activation: callable applied to the final linear output,
                or None to return the raw linear output.
            layers (Sequence[int] | None): widths of the linear layers,
                including the output width as the last entry. Defaults to
                ``[64, 32, 1]``.

        Raises:
            ValueError: if ``layers`` is given but empty.
        """
        super(Generator, self).__init__()
        widths = [64, 32, 1] if layers is None else list(layers)
        if not widths:
            raise ValueError("layers must contain at least the output width")

        # The first layer and the activation are assigned explicitly so the
        # submodule registration order (linear1, leaky_relu, linear2, ...)
        # is the same as with the previously hard-coded layers.
        self.linear1 = nn.Linear(latent_dim, widths[0])
        self.leaky_relu = nn.LeakyReLU()

        # Remaining layers are attached as linear2, linear3, ... so the
        # default configuration keeps its attribute and state-dict names.
        in_width = widths[0]
        for index, out_width in enumerate(widths[1:], start=2):
            setattr(self, "linear{}".format(index), nn.Linear(in_width, out_width))
            in_width = out_width

        self._num_linear = len(widths)
        self.output_activation = output_activation

    def forward(self, input_tensor):
        """Forward pass; map latent vectors to samples.

        Args:
            input_tensor: input whose last dimension is ``latent_dim``
                (as required by the first ``nn.Linear``).

        Returns:
            Output of the last linear layer, passed through
            ``output_activation`` when one was supplied.
        """
        intermediate = input_tensor
        last_index = self._num_linear
        for index in range(1, last_index + 1):
            linear = getattr(self, "linear{}".format(index))
            intermediate = linear(intermediate)
            # The leaky ReLU sits between layers only; the final layer's
            # output is left for the optional output activation.
            if index < last_index:
                intermediate = self.leaky_relu(intermediate)
        if self.output_activation is not None:
            intermediate = self.output_activation(intermediate)
        return intermediate
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment