Skip to content

Instantly share code, notes, and snippets.

@udithhaputhanthri
Created March 26, 2021 08:23
Show Gist options
  • Save udithhaputhanthri/77546d8d1dce83c87e28dcda9e19b3f4 to your computer and use it in GitHub Desktop.
Save udithhaputhanthri/77546d8d1dce83c87e28dcda9e19b3f4 to your computer and use it in GitHub Desktop.
WGAN_scripts
class Generator(nn.Module):
    """Convolutional generator: noise tensor in, image tensor out.

    The network is a stack of four ``conv_trans_block`` stages (helper defined
    elsewhere in the project), then a plain ``nn.ConvTranspose2d`` projecting to
    ``img_channels``, then ``nn.Tanh`` so outputs lie in (-1, 1).

    Args:
        noise_channels: channel count of the input noise tensor. Presumably the
            noise is shaped (N, noise_channels, 1, 1); TODO confirm against caller.
        img_channels: channel count of the generated image.
        hidden_G: base width. The four stages emit 16x, 8x, 4x and 2x this many
            channels.
    """

    def __init__(self, noise_channels, img_channels, hidden_G):
        super().__init__()

        # NOTE(review): ``kernal_size`` is kept verbatim because it must match
        # the keyword name in ``conv_trans_block``'s signature (not visible
        # here). Confirm it is not a typo for ``kernel_size``.
        stages = [
            conv_trans_block(noise_channels, hidden_G * 16,
                             kernal_size=4, stride=1, padding=0),
        ]
        # The remaining stages use conv_trans_block's default geometry and
        # halve the width each time.
        for mult_in, mult_out in ((16, 8), (8, 4), (4, 2)):
            stages.append(conv_trans_block(hidden_G * mult_in,
                                           hidden_G * mult_out))

        stages.append(nn.ConvTranspose2d(hidden_G * 2, img_channels,
                                         kernel_size=4, stride=2, padding=1))
        stages.append(nn.Tanh())

        # Attribute name ``G`` and layer order are preserved so existing
        # state_dict keys ("G.0", "G.1", ...) still load.
        self.G = nn.Sequential(*stages)

    def forward(self, x):
        """Run the noise batch ``x`` through the generator stack."""
        generated = self.G(x)
        return generated
class Critic(nn.Module):
    """Convolutional critic: image tensor in, one-channel score map out.

    The network is four ``conv_block`` stages (helper defined elsewhere in the
    project) doubling the width each time, then an ``nn.Conv2d`` down to one
    channel. The stack has no final activation, so the output is an unbounded
    real-valued score.

    Args:
        img_channels: channel count of the input image.
        hidden_D: base width. The four stages emit 1x, 2x, 4x and 8x this many
            channels.
    """

    def __init__(self, img_channels, hidden_D):
        super().__init__()
        # Bug fix: the original body referenced ``hidden_G``, which is not a
        # parameter here. That raised NameError, or silently picked up an
        # unrelated global, ignoring ``hidden_D``.
        self.D = nn.Sequential(
            conv_block(img_channels, hidden_D),
            conv_block(hidden_D, hidden_D * 2),
            conv_block(hidden_D * 2, hidden_D * 4),
            conv_block(hidden_D * 4, hidden_D * 8),
            # Assumes the preceding stages leave a 4x4 feature map, so this
            # yields a 1x1 score per sample. TODO confirm against the
            # image size and conv_block's stride.
            nn.Conv2d(hidden_D * 8, 1, kernel_size=4, stride=2, padding=0),
        )

    def forward(self, x):
        """Return the critic's score map for the image batch ``x``."""
        return self.D(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment