Skip to content

Instantly share code, notes, and snippets.

@lantiga
Last active February 6, 2018 23:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lantiga/c7dbfb5c12b5bb75a78fe130c4623f7e to your computer and use it in GitHub Desktop.
CycleGAN pretrained: run a pretrained ResNet-based CycleGAN generator on a single image
import torch
import torch.nn as nn
class ResnetBlock(nn.Module):
    """Residual block computing ``x + F(x)``.

    ``F`` is: reflection-pad(1) -> 3x3 conv -> InstanceNorm -> ReLU ->
    reflection-pad(1) -> 3x3 conv -> InstanceNorm. Both convolutions map
    ``dim`` channels to ``dim`` channels and the padding offsets the 3x3
    kernel, so ``F(x)`` has the same shape as ``x`` and the skip connection
    is a plain addition.

    Args:
        dim: number of channels of the input (and output) feature map.
    """

    def __init__(self, dim):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        """Build and return the residual branch ``F`` as an ``nn.Sequential``.

        The child order fixes the state-dict keys (``conv_block.1.*`` and
        ``conv_block.5.*`` hold the two convolutions). Those keys must match
        whatever checkpoint is later passed to ``load_state_dict``.
        """
        def padded_conv_norm():
            # Reflection padding of 1 keeps H and W unchanged through the 3x3 conv.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                nn.InstanceNorm2d(dim),
            ]

        head = padded_conv_norm()
        tail = padded_conv_norm()
        return nn.Sequential(*(head + [nn.ReLU(True)] + tail))

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual
class ResnetGenerator(nn.Module):
    """CycleGAN-style ResNet image-to-image generator.

    ``self.model`` is an ``nn.Sequential`` made of:

    * a stem: reflection-pad(3) -> 7x7 conv (``input_nc`` -> ``ngf``);
    * ``n_downsampling`` stride-2 3x3 convs, each doubling the channels;
    * ``n_blocks`` ``ResnetBlock``s at ``ngf * 2**n_downsampling`` channels;
    * ``n_downsampling`` stride-2 3x3 transposed convs, each halving the
      channels;
    * a head: reflection-pad(3) -> 7x7 conv (``ngf`` -> ``output_nc``) -> tanh.

    Every conv except the last is followed by ``InstanceNorm2d`` and ``ReLU``.
    The output lies in [-1, 1] because of the final tanh. When H and W are
    divisible by ``2**n_downsampling``, the output has the input's spatial
    size.

    Args:
        input_nc: number of channels of the input image.
        output_nc: number of channels of the output image.
        ngf: channel count after the stem conv; doubled at each downsampling.
        n_blocks: number of residual blocks at the bottleneck (must be >= 0).
        n_downsampling: number of stride-2 down/up-sampling stages (must be
            >= 0). Defaults to 2, the value previously hard-coded here.
            Changing it changes the ``model.<index>`` keys in ``state_dict``.

    Raises:
        ValueError: if ``n_blocks`` or ``n_downsampling`` is negative.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_blocks=6, n_downsampling=2):
        # Explicit checks rather than ``assert`` so validation survives ``python -O``.
        if n_blocks < 0:
            raise ValueError('n_blocks must be >= 0, got {}'.format(n_blocks))
        if n_downsampling < 0:
            raise ValueError('n_downsampling must be >= 0, got {}'.format(n_downsampling))
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        # Stem: reflection padding of 3 keeps H/W unchanged through the 7x7 conv.
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(True)]

        # Encoder: each stage halves H/W (stride 2) and doubles the channels.
        for i in range(n_downsampling):
            in_ch = ngf * 2 ** i
            out_ch = in_ch * 2
            model += [nn.Conv2d(in_ch, out_ch, kernel_size=3,
                                stride=2, padding=1, bias=True),
                      nn.InstanceNorm2d(out_ch),
                      nn.ReLU(True)]

        # Bottleneck: shape-preserving residual blocks.
        bottleneck_ch = ngf * 2 ** n_downsampling
        for _ in range(n_blocks):
            model += [ResnetBlock(bottleneck_ch)]

        # Decoder: mirror of the encoder. With kernel 3, stride 2, padding 1,
        # output_padding=1 makes each transposed conv exactly double H/W.
        for i in range(n_downsampling):
            in_ch = ngf * 2 ** (n_downsampling - i)
            out_ch = in_ch // 2  # exact integer halving (in_ch is even here)
            model += [nn.ConvTranspose2d(in_ch, out_ch,
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=True),
                      nn.InstanceNorm2d(out_ch),
                      nn.ReLU(True)]

        # Head: project back to output_nc channels and squash to [-1, 1].
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                  nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Translate a batch of images, assumed shaped (N, input_nc, H, W)."""
        return self.model(input)
if __name__ == '__main__':
    # Script entry point: translate one image with a pretrained CycleGAN
    # ResNet generator and display the result.
    #
    # Usage: python <this file> MODEL_PATH IMAGE_PATH
    #   MODEL_PATH  file holding a ResnetGenerator state_dict for 3-channel
    #               input/output, ngf=64 and 9 residual blocks.
    #   IMAGE_PATH  any image file PIL can open.
    import sys

    from PIL import Image
    from torchvision import transforms

    if len(sys.argv) < 3:
        sys.exit('usage: {} MODEL_PATH IMAGE_PATH'.format(sys.argv[0]))
    model_path = sys.argv[1]
    image_path = sys.argv[2]

    input_nc = 3
    output_nc = 3
    ngf = 64
    n_blocks = 9

    netG = ResnetGenerator(input_nc, output_nc, ngf, n_blocks=n_blocks)
    # map_location='cpu' lets GPU-saved checkpoints load on a CPU-only
    # machine; the model is never moved to a GPU in this script.
    # SECURITY: torch.load unpickles the file -- only load trusted checkpoints.
    # NOTE(review): checkpoints written by older PyTorch versions may carry
    # InstanceNorm running-stat entries this model does not define, which
    # makes load_state_dict fail -- confirm against the checkpoint in use.
    netG.load_state_dict(torch.load(model_path, map_location='cpu'))
    netG.eval()

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.ToTensor(),
        # ToTensor yields [0, 1]. This maps to [-1, 1], mirroring the decode
        # of the tanh output below. It assumes the checkpoint was trained on
        # [-1, 1] inputs, as in the reference CycleGAN code.
        transforms.Normalize([0.5] * input_nc, [0.5] * input_nc),
    ])

    # convert('RGB') guarantees exactly input_nc=3 channels; grayscale,
    # palette or RGBA files would otherwise fail in the first conv.
    with Image.open(image_path) as img:
        img_t = preprocess(img.convert('RGB'))

    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        out = netG(img_t.unsqueeze(0))

    # Drop only the batch dimension, then map tanh output [-1, 1] -> [0, 1].
    out_t = (out[0] + 1.0) / 2.0
    out_img = transforms.ToPILImage()(out_t)
    out_img.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment