Created May 31, 2017 15:38
-
-
Save hiepph/125ebe0795ca9f5bfac3328b4d604928 to your computer and use it in GitHub Desktop.
pytorch-CycleGAN-and-pix2pix single image prediction
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix | |
from torch.autograd import Variable | |
from torchvision import transforms | |
from PIL import Image | |
from options.test_options import TestOptions | |
from models.models import create_model | |
import util.util as util | |
def main(input_path='./real_A.jpg', output_path='./fake_B.png'):
    """Translate a single image from domain A to domain B and save the result.

    Command-line options are parsed with ``TestOptions`` and then several of
    them are overridden so that exactly one image is pushed through a
    ``unet_256`` generator in the ``AtoB`` direction using the
    ``one_direction_test`` model.

    Args:
        input_path: Path of the source (domain A) image to translate.
            Defaults to ``./real_A.jpg`` (the previously hard-coded path).
        output_path: Where the generated (domain B) image is written.
            Defaults to ``./fake_B.png`` (the previously hard-coded path).
    """
    # Parse command-line options, then force single-image settings.
    opt = TestOptions().parse()
    opt.nThreads = 1
    opt.batchSize = 1
    opt.serial_batches = True
    opt.use_dropout = True
    opt.align_data = True
    opt.model = 'one_direction_test'
    opt.which_model_netG = 'unet_256'
    opt.which_direction = 'AtoB'

    # Build the model described by the options above.
    model = create_model(opt)

    # ``transforms.Scale`` was renamed to ``transforms.Resize`` and later
    # removed; prefer the new name and fall back for older torchvision.
    resize = getattr(transforms, 'Resize', None) or transforms.Scale
    preprocess = transforms.Compose([
        resize(opt.loadSize),
        # A center crop keeps the prediction deterministic; a RandomCrop
        # would translate a different region of the image on every run.
        transforms.CenterCrop(opt.fineSize),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 per channel: maps ToTensor's [0, 1] range to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5),
                             (0.5, 0.5, 0.5)),
    ])

    # Force 3-channel RGB so grayscale/RGBA/palette images match the
    # 3-channel Normalize above; the context manager closes the file.
    with Image.open(input_path) as img:
        real = img.convert('RGB')

    # Add a leading batch dimension (batch of one image).
    input_A = preprocess(real).unsqueeze(0)
    # Copy into the model's pre-allocated input tensor, resizing it to fit.
    model.input_A.resize_(input_A.size()).copy_(input_A)

    # Run the model; the generated image is read back from model.fake_B.
    model.test()

    # Convert the generated tensor to an image array and write it out.
    fake = util.tensor2im(model.fake_B.data)
    util.save_image(fake, output_path)
if __name__ == "__main__":
    # Only run the prediction when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.