Skip to content

Instantly share code, notes, and snippets.

@kosuke1701
Last active May 5, 2021 22:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kosuke1701/4e2ac722bb0c4af9cbf9acfec3d91c3f to your computer and use it in GitHub Desktop.
Save kosuke1701/4e2ac722bb0c4af9cbf9acfec3d91c3f to your computer and use it in GitHub Desktop.
Wrapper class that uses a pre-trained [Pix2PixHD](https://github.com/NVIDIA/pix2pixHD) checkpoint to convert images. Set the `PIX2PIX_DIR` environment variable to the directory of a cloned Pix2PixHD project, and follow the Pix2PixHD project's README to set up its dependencies.
# Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu.
# Copyright (C) 2021 kosuke1701
# BSD License. All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
# IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
# DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
# OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from argparse import Namespace
import os
import sys
from PIL import Image
import torch
sys.path.append(os.environ["PIX2PIX_DIR"])
from models.models import create_model
from options.test_options import TestOptions
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
import util
class Pix2PixHD_Converter:
    """Run a pre-trained Pix2PixHD generator on single images.

    The constructor parses Pix2PixHD ``TestOptions`` from a list of
    command-line tokens and builds the model once; ``process`` converts one
    input image into a generated ``PIL.Image`` resized back to the input size.
    """

    def __init__(self, arg_list):
        """Build the test-time model.

        Args:
            arg_list: list of command-line tokens accepted by Pix2PixHD's
                ``TestOptions`` parser (e.g. ``["--name", "exp", "--label_nc", "0"]``).
        """
        parser = TestOptions()
        if not parser.initialized:
            parser.initialize()
        opt = parser.parser.parse_args(arg_list)
        opt.isTrain = False

        # --gpu_ids is a comma-separated string; negative ids are dropped, so
        # e.g. "-1" yields an empty list and no CUDA device is selected here.
        opt.gpu_ids = self._parse_gpu_ids(opt.gpu_ids)
        if opt.gpu_ids:
            torch.cuda.set_device(opt.gpu_ids[0])

        opt.nThreads = 1  # test code only supports nThreads = 1
        opt.batchSize = 1  # test code only supports batchSize = 1
        opt.serial_batches = True  # no shuffle
        opt.no_flip = True  # no flip

        # Build the model exactly once (it was previously constructed twice,
        # loading the checkpoint twice and discarding the first instance).
        model = create_model(opt)
        if opt.data_type == 16:
            model.half()
        elif opt.data_type == 8:
            model.type(torch.uint8)
        if opt.verbose:
            print(model)

        self.opt = opt
        self.model = model

    @staticmethod
    def _parse_gpu_ids(gpu_ids_str):
        """Return the non-negative integer ids from a comma-separated string."""
        parsed = (int(token) for token in gpu_ids_str.split(','))
        return [gpu_id for gpu_id in parsed if gpu_id >= 0]

    @staticmethod
    def _cast_for_data_type(tensor, data_type):
        """Cast an input tensor to match the precision chosen by --data_type.

        ``None`` (an absent optional input) is passed through unchanged, so a
        model run with --no_instance no longer crashes on half/uint8 casting.
        """
        if tensor is None:
            return None
        if data_type == 16:
            return tensor.half()
        if data_type == 8:
            # torch.Tensor has no ``uint8()`` method; convert via ``to``.
            return tensor.to(torch.uint8)
        return tensor

    def process(self, A_path, B_path=None, inst_path=None, feat_path=None):
        """Convert one image with the loaded generator.

        Args:
            A_path: path to the input image (label map, or an RGB image when
                ``--label_nc 0``).
            B_path: path to the real image; required only when the options
                enable ``use_encoded_image``.
            inst_path: path to the instance map; required unless
                ``--no_instance`` was given.
            feat_path: path to a precomputed feature map; required only when
                ``--load_features`` was given (optional, defaults to None).

        Returns:
            PIL.Image.Image: the generated image, resized to the size of the
            image at ``A_path``.

        Raises:
            ValueError: if a path required by the current options is missing.
        """
        opt = self.opt

        ### input A (label maps)
        with Image.open(A_path) as A:
            original_size = A.size
            params = get_params(opt, original_size)
            if opt.label_nc == 0:
                transform_A = get_transform(opt, params)
                A_tensor = transform_A(A.convert('RGB'))
            else:
                transform_A = get_transform(opt, params, method=Image.NEAREST, normalize=False)
                A_tensor = transform_A(A) * 255.0

        B_tensor = inst_tensor = feat_tensor = None

        ### input B (real images)
        if opt.isTrain or opt.use_encoded_image:
            if B_path is None:
                raise ValueError("B_path is required when use_encoded_image is enabled")
            with Image.open(B_path) as B:
                transform_B = get_transform(opt, params)
                B_tensor = transform_B(B.convert('RGB'))

        ### if using instance maps
        if not opt.no_instance:
            if inst_path is None:
                raise ValueError("inst_path is required unless --no_instance is given")
            with Image.open(inst_path) as inst:
                inst_tensor = transform_A(inst)

        if opt.load_features:
            if feat_path is None:
                raise ValueError("feat_path is required when --load_features is given")
            with Image.open(feat_path) as feat:
                norm = normalize()
                feat_tensor = norm(transform_A(feat.convert('RGB')))

        def _batched(tensor):
            # Add a leading batch dimension to present inputs; keep None as None.
            return tensor.unsqueeze(0) if tensor is not None else None

        data = {
            'label': self._cast_for_data_type(_batched(A_tensor), opt.data_type),
            'inst': self._cast_for_data_type(_batched(inst_tensor), opt.data_type),
            'image': _batched(B_tensor),
            # 'feat' and 'path' are kept for parity with the dataset dict;
            # the inference call below only consumes label, inst and image.
            'feat': _batched(feat_tensor),
            'path': A_path,
        }

        with torch.no_grad():
            generated = self.model.inference(data['label'], data['inst'], data['image'])
        generated_image = util.util.tensor2im(generated.data[0])
        return Image.fromarray(generated_image).resize(original_size)
if __name__ == "__main__":
    import shlex

    # NOTE: PIX2PIX_DIR must be set in the environment before this module is
    # imported, because the Pix2PixHD imports at the top of the file use it.
    option_str = "--name train_danboo_region_val --label_nc 0 --no_instance --resize_or_crop none --checkpoints_dir pix2pixHD/checkpoints"
    # shlex.split tolerates repeated whitespace and quoted paths, whereas
    # str.split(" ") would produce empty-string arguments on double spaces.
    converter = Pix2PixHD_Converter(shlex.split(option_str))
    gen_image = converter.process("20210421.png")  # PIL Image
    gen_image.save("20210421_skeleton.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment