Skip to content

Instantly share code, notes, and snippets.

@jinglescode
Created January 15, 2021 15:46
Show Gist options
  • Save jinglescode/69db7249a12dd7afc323b9a66592098f to your computer and use it in GitHub Desktop.
Save jinglescode/69db7249a12dd7afc323b9a66592098f to your computer and use it in GitHub Desktop.
class Generator(nn.Module):
    '''
    Fully connected GAN generator that maps noise vectors to flattened images.

    The network stacks four ``generator_block``s whose output widths are
    ``hidden_dim``, ``2 * hidden_dim``, ``4 * hidden_dim`` and ``8 * hidden_dim``,
    followed by an ``nn.Linear`` projecting to ``in_dim`` features and an
    ``nn.Sigmoid`` that squashes every output value into ``[0, 1]``.

    Parameters:
        dim_noise: int, default: 10
            the dimension of the input noise vector
        in_dim: int, default: 784
            the dimension of the (flattened) images the generator produces,
            fitted for the dataset used (MNIST images are 28x28, so 784 is the
            default). Despite its name, this is the generator's *output* size;
            the name is kept for backward compatibility.
        hidden_dim: int, default: 128
            the width of the first hidden block; each later block doubles it
    '''

    def __init__(self, dim_noise=10, in_dim=784, hidden_dim=128):
        super().__init__()
        # Block widths double at each stage: hidden_dim, 2x, 4x, 8x.
        dims = [hidden_dim * 2 ** i for i in range(4)]
        # Chain the blocks: dim_noise -> dims[0] -> dims[1] -> dims[2] -> dims[3].
        # Blocks are built in order, so layer indices (and state_dict keys
        # gen.0 ... gen.5) match the original hand-written layout.
        blocks = [
            self.generator_block(d_in, d_out)
            for d_in, d_out in zip([dim_noise, *dims[:-1]], dims)
        ]
        self.gen = nn.Sequential(
            *blocks,
            nn.Linear(dims[-1], in_dim),
            nn.Sigmoid(),
        )

    def forward(self, noise):
        '''
        Generate images from a batch of noise vectors.

        Parameters:
            noise: a tensor whose last dimension is ``dim_noise``; since the first
                layer is ``Linear(dim_noise, ...)`` followed by ``BatchNorm1d``,
                this is in practice a 2-D ``(batch, dim_noise)`` tensor.
                NOTE: in training mode ``BatchNorm1d`` uses batch statistics,
                so a batch of size 1 is rejected; use ``eval()`` for single samples.
        Returns:
            a ``(batch, in_dim)`` tensor with values in ``[0, 1]`` (Sigmoid output).
        '''
        return self.gen(noise)

    def generator_block(self, input_dim, output_dim):
        '''
        A generator neural network layer, with a linear transformation, batchnorm and relu.

        Parameters:
            input_dim: number of input features of the linear layer
            output_dim: number of output features (also the BatchNorm1d width)
        Returns:
            an ``nn.Sequential`` of ``Linear -> BatchNorm1d -> ReLU``
        '''
        return nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.BatchNorm1d(output_dim),
            # In-place ReLU overwrites the BatchNorm output to save memory.
            nn.ReLU(inplace=True),
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment