Skip to content

Instantly share code, notes, and snippets.

View KushajveerSingh's full-sized avatar
🇺🇸
Working from home

Kushajveer Singh KushajveerSingh

🇺🇸
Working from home
View GitHub Profile
@KushajveerSingh
KushajveerSingh / SPADE
Created April 19, 2019 07:49
SPADE model from the paper 1903.07291, my implementation.
class SPADE(Module):
def __init__(self, args, k):
super().__init__()
num_filters = args.spade_filter
kernel_size = args.spade_kernel
self.conv = spectral_norm(Conv2d(1, num_filters, kernel_size=(kernel_size, kernel_size), padding=1))
self.conv_gamma = spectral_norm(Conv2d(num_filters, k, kernel_size=(kernel_size, kernel_size), padding=1))
self.conv_beta = spectral_norm(Conv2d(num_filters, k, kernel_size=(kernel_size, kernel_size), padding=1))
def forward(self, x, seg):
@KushajveerSingh
KushajveerSingh / SPADEResBlk
Last active April 19, 2019 08:35
SPADEResBlk model from the paper 1903.07291, my implementation
class SPADEResBlk(Module):
def __init__(self, args, k, skip=False):
super().__init__()
kernel_size = args.spade_resblk_kernel
self.skip = skip
if self.skip:
self.spade1 = SPADE(args, 2*k)
self.conv1 = Conv2d(2*k, k, kernel_size=(kernel_size, kernel_size), padding=1, bias=False)
self.spade_skip = SPADE(args, 2*k)
@KushajveerSingh
KushajveerSingh / He weight init
Created April 19, 2019 08:41
Normal-distribution (std 0.02) weight initialization for conv and batch-norm layers
def weights_init(m):
    """Re-initialise conv and batch-norm parameters of ``m`` in place.

    Dispatch is by class *name*: a module whose class name contains
    ``Conv`` gets its weight drawn from N(0.0, 0.02); one whose name
    contains ``BatchNorm`` gets its weight drawn from N(1.0, 0.02) and
    its bias set to 0. Every other module is left untouched.

    The signature (one module in, nothing returned) presumably targets
    ``Module.apply(weights_init)`` -- confirm against the caller.

    Args:
        m: the module to initialise. Parameters that are absent or
            ``None`` are skipped instead of raising.
    """
    classname = m.__class__.__name__
    # getattr guards: name matching also hits weight-less modules (e.g. a
    # user container called ``ConvBlock``), and batch-norm layers built
    # with affine=False carry weight/bias set to None.
    weight = getattr(m, 'weight', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            nn.init.constant_(bias.data, 0)
@KushajveerSingh
KushajveerSingh / SPADEGenerator
Created April 19, 2019 08:47
SPADEGenerator implementation from the paper 1903.07291, my implementation
class SPADEGenerator(nn.Module):
def __init__(self, args):
super().__init__()
self.linear = Linear(args.gen_input_size, args.gen_hidden_size)
self.spade_resblk1 = SPADEResBlk(args, 1024)
self.spade_resblk2 = SPADEResBlk(args, 1024)
self.spade_resblk3 = SPADEResBlk(args, 1024)
self.spade_resblk4 = SPADEResBlk(args, 512)
self.spade_resblk5 = SPADEResBlk(args, 256)
self.spade_resblk6 = SPADEResBlk(args, 128)
@KushajveerSingh
KushajveerSingh / SPADEDiscriminator
Created April 19, 2019 08:50
SPADEDiscriminator implementation from the paper 1903.07291, my implementation
def custom_model1(in_chan, out_chan):
    """Build a downsampling stage: spectrally normalised conv, then LeakyReLU.

    The convolution uses a 4x4 kernel with stride 2 and padding 1, so an
    even-sized input comes out at half its height and width.

    Args:
        in_chan: number of input channels.
        out_chan: number of output channels.

    Returns:
        An ``nn.Sequential`` holding the two layers in order.
    """
    conv = nn.Conv2d(in_chan, out_chan, kernel_size=(4, 4), stride=2, padding=1)
    stages = [
        spectral_norm(conv),
        nn.LeakyReLU(inplace=True),
    ]
    return nn.Sequential(*stages)
def custom_model2(in_chan, out_chan, stride=2):
return nn.Sequential(
spectral_norm(nn.Conv2d(in_chan, out_chan, kernel_size=(4,4), stride=stride, padding=1)),
nn.InstanceNorm2d(out_chan),
@KushajveerSingh
KushajveerSingh / PerceptualLoss
Last active April 19, 2019 09:36
PerceptualLoss to get features extracted from pretrained VGG19
class VGGLoss(nn.Module):
def __init__(self):
super().__init__()
self.vgg = VGG19().cuda()
self.criterion = nn.L1Loss()
self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
def forward(self, x, y):
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
loss = 0
class GANLoss(nn.Module):
def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
tensor=torch.FloatTensor):
super().__init__()
self.real_label = target_real_label
self.fake_label = target_fake_label
self.real_label_var = None
self.fake_label_var = None
self.Tensor = tensor
if use_lsgan:
def conv_inst_lrelu(in_chan, out_chan):
    """Build a Conv2d -> InstanceNorm2d -> LeakyReLU downsampling stage.

    The convolution is 3x3 with stride 2 and padding 1 and carries no
    bias term.

    Args:
        in_chan: number of input channels.
        out_chan: number of output channels.

    Returns:
        An ``nn.Sequential`` holding the three layers in order.
    """
    conv = nn.Conv2d(
        in_chan,
        out_chan,
        kernel_size=(3, 3),
        stride=2,
        bias=False,
        padding=1,
    )
    norm = nn.InstanceNorm2d(out_chan)
    activation = nn.LeakyReLU(inplace=True)
    return nn.Sequential(conv, norm, activation)
class SPADEEncoder(nn.Module):
def __init__(self, args):
super().__init__()
@KushajveerSingh
KushajveerSingh / st_blog_1
Created May 5, 2019 13:59
Load VGG model and pytorch essentials
# Fail fast: the rest of the script assumes a CUDA device.
if not torch.cuda.is_available():
    raise Exception('GPU is not available')
device = torch.device('cuda')

# Keep only the `.features` part of the pretrained VGG19 (the CNN
# layers; the last linear layers are not needed) and move it to the
# selected device.
vgg = vgg19(pretrained=True).features.to(device)
# We don't want to train VGG
@KushajveerSingh
KushajveerSingh / st_blog_2
Last active May 5, 2019 14:07
Load images and show results
# Load both images straight onto the compute device. Only the content
# image is given an explicit size; the style image uses load_image's default.
content_img = load_image(
    os.path.join(args.img_root, args.content_img),
    size=500,
).to(device)
style_img = load_image(
    os.path.join(args.img_root, args.style_img),
).to(device)

# Side-by-side preview: content on the left, style on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
ax1.imshow(im_convert(content_img))
ax2.imshow(im_convert(style_img))