This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SPADE(Module): | |
def __init__(self, args, k): | |
super().__init__() | |
num_filters = args.spade_filter | |
kernel_size = args.spade_kernel | |
self.conv = spectral_norm(Conv2d(1, num_filters, kernel_size=(kernel_size, kernel_size), padding=1)) | |
self.conv_gamma = spectral_norm(Conv2d(num_filters, k, kernel_size=(kernel_size, kernel_size), padding=1)) | |
self.conv_beta = spectral_norm(Conv2d(num_filters, k, kernel_size=(kernel_size, kernel_size), padding=1)) | |
def forward(self, x, seg): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SPADEResBlk(Module): | |
def __init__(self, args, k, skip=False): | |
super().__init__() | |
kernel_size = args.spade_resblk_kernel | |
self.skip = skip | |
if self.skip: | |
self.spade1 = SPADE(args, 2*k) | |
self.conv1 = Conv2d(2*k, k, kernel_size=(kernel_size, kernel_size), padding=1, bias=False) | |
self.spade_skip = SPADE(args, 2*k) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _trainable_weight(m):
    """Return the tensor that stores *m*'s learnable weight, or None if absent."""
    # torch.nn.utils.spectral_norm moves the learnable weight to ``weight_orig``;
    # ``m.weight`` is then a derived tensor recomputed on each forward pass.
    orig = getattr(m, 'weight_orig', None)
    if orig is not None:
        return orig
    # The torch.nn.utils.parametrizations API keeps it under
    # ``m.parametrizations.weight.original`` instead.
    params = getattr(m, 'parametrizations', None)
    if params is not None and 'weight' in params:
        return params.weight.original
    return getattr(m, 'weight', None)


def weights_init(m):
    """Apply DCGAN-style initialisation to a single module, in place.

    Intended for ``model.apply(weights_init)``.

    * Modules whose class name contains ``'Conv'`` get their weight drawn
      from N(0, 0.02).
    * Modules whose class name contains ``'BatchNorm'`` get their weight
      drawn from N(1, 0.02) and their bias set to 0.
    * Every other module is left untouched.

    Spectrally-normalised convolutions have their *trainable* tensor
    initialised. Writing to ``m.weight`` there would be overwritten on the
    next forward pass.

    Missing or ``None`` weights and biases are skipped rather than raising
    ``AttributeError``. Examples are ``affine=False`` BatchNorm, or a custom
    module whose name merely contains ``'Conv'``.

    Args:
        m: The module to initialise.
    """
    classname = type(m).__name__
    if 'Conv' in classname:
        weight = _trainable_weight(m)
        if weight is not None:
            # nn.init functions already run under torch.no_grad(), so no .data needed.
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if getattr(m, 'weight', None) is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            nn.init.constant_(m.bias, 0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SPADEGenerator(nn.Module): | |
def __init__(self, args): | |
super().__init__() | |
self.linear = Linear(args.gen_input_size, args.gen_hidden_size) | |
self.spade_resblk1 = SPADEResBlk(args, 1024) | |
self.spade_resblk2 = SPADEResBlk(args, 1024) | |
self.spade_resblk3 = SPADEResBlk(args, 1024) | |
self.spade_resblk4 = SPADEResBlk(args, 512) | |
self.spade_resblk5 = SPADEResBlk(args, 256) | |
self.spade_resblk6 = SPADEResBlk(args, 128) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def custom_model1(in_chan, out_chan, stride=2):
    """Build a spectrally-normalised 4x4 convolution followed by LeakyReLU.

    With the default ``stride=2`` and ``padding=1``, an even-sized input has
    its spatial size halved. No normalisation layer sits between the
    convolution and the activation.

    Args:
        in_chan: Number of input channels.
        out_chan: Number of output channels.
        stride: Convolution stride. Defaults to 2, the previously hard-coded
            value, mirroring the ``stride`` parameter of ``custom_model2``.

    Returns:
        nn.Sequential: ``[spectral_norm(Conv2d), LeakyReLU(inplace=True)]``.
        The LeakyReLU uses PyTorch's default negative slope.
    """
    return nn.Sequential(
        spectral_norm(
            nn.Conv2d(in_chan, out_chan, kernel_size=(4, 4), stride=stride, padding=1)
        ),
        nn.LeakyReLU(inplace=True),
    )
def custom_model2(in_chan, out_chan, stride=2): | |
return nn.Sequential( | |
spectral_norm(nn.Conv2d(in_chan, out_chan, kernel_size=(4,4), stride=stride, padding=1)), | |
nn.InstanceNorm2d(out_chan), |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class VGGLoss(nn.Module): | |
def __init__(self): | |
super().__init__() | |
self.vgg = VGG19().cuda() | |
self.criterion = nn.L1Loss() | |
self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] | |
def forward(self, x, y): | |
x_vgg, y_vgg = self.vgg(x), self.vgg(y) | |
loss = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class GANLoss(nn.Module): | |
def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, | |
tensor=torch.FloatTensor): | |
super().__init__() | |
self.real_label = target_real_label | |
self.fake_label = target_fake_label | |
self.real_label_var = None | |
self.fake_label_var = None | |
self.Tensor = tensor | |
if use_lsgan: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def conv_inst_lrelu(in_chan, out_chan, stride=2):
    """Build a 3x3 Conv2d -> InstanceNorm2d -> LeakyReLU unit.

    With the default ``stride=2`` and ``padding=1``, an input of spatial size
    H produces ceil(H / 2).

    Args:
        in_chan: Number of input channels.
        out_chan: Number of output channels.
        stride: Convolution stride. Defaults to 2, the previously hard-coded
            value, mirroring the ``stride`` parameter of ``custom_model2``.

    Returns:
        nn.Sequential: ``[Conv2d(bias=False), InstanceNorm2d, LeakyReLU(inplace=True)]``.
    """
    return nn.Sequential(
        # No bias: InstanceNorm2d subtracts each channel's per-instance mean,
        # which would cancel a constant per-channel bias anyway.
        nn.Conv2d(in_chan, out_chan, kernel_size=(3, 3), stride=stride, bias=False, padding=1),
        nn.InstanceNorm2d(out_chan),
        nn.LeakyReLU(inplace=True),
    )
class SPADEEncoder(nn.Module): | |
def __init__(self, args): | |
super().__init__() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Only a CUDA device is supported; fail fast if none is present.
# RuntimeError is an Exception subclass, so existing `except Exception` handlers still catch it.
if not torch.cuda.is_available():
    raise RuntimeError('GPU is not available')
device = torch.device('cuda')

# Load VGG19 features. We do not need the last linear layers,
# only CNN layers are needed
vgg = vgg19(pretrained=True).features.to(device)
# We don't want to train VGG
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load the content and style images from args.img_root and move them to `device`.
# NOTE(review): size=500 is passed only for the content image; presumably a
# resize bound in load_image — confirm against its definition.
content_img = load_image(os.path.join(args.img_root, args.content_img), size=500).to(device)
style_img = load_image(os.path.join(args.img_root, args.style_img)).to(device)

# Show content and style image side by side
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
ax1.imshow(im_convert(content_img))
ax2.imshow(im_convert(style_img))
Older / Newer