Created
July 30, 2021 03:55
-
-
Save e96031413/4a7becaffd981e757feaa161e4e72da8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import argparse | |
from PIL import Image | |
import torch | |
from torchvision import transforms | |
from torchvision.utils import save_image | |
from model import Model | |
# Directory that receives the composed demo images saved by main().
output_folder = 'output_image'
# Check for a *directory*: os.path.isfile() is False for an existing directory,
# which made every re-run fall through to os.makedirs() and crash with
# FileExistsError.
if os.path.isdir(output_folder):
    print("output folder exists!")
else:
    # Still raises FileExistsError if a regular file occupies this name, which
    # is better than continuing and failing later when saving into it.
    os.makedirs(output_folder)
# Preprocessing applied to both the content and the style image before they
# are fed to the model: PIL image -> float tensor in [0, 1] (ToTensor), then
# per-channel (x - mean) / std with the commonly used ImageNet statistics.
# denorm() below inverts `normalize` (and clamps the result to [0, 1]).
normalize = transforms.Normalize(
    std=[0.229, 0.224, 0.225],
    mean=[0.485, 0.456, 0.406],
)
trans = transforms.Compose(
    [
        transforms.ToTensor(),
        normalize,
    ]
)
def denorm(tensor, device=None):
    """Undo ``normalize``: map a normalized image tensor back into [0, 1].

    Computes ``clamp(tensor * std + mean, 0, 1)`` using the same per-channel
    statistics as the module-level ``normalize`` transform.

    Args:
        tensor: normalized image tensor whose channel axis is third from the
            end (e.g. ``(C, H, W)`` or ``(N, C, H, W)``), so that the
            ``(3, 1, 1)`` statistics broadcast over it.
        device: device on which to build the statistics tensors. Defaults to
            ``tensor.device`` so callers can omit it.

    Returns:
        Tensor of the broadcast shape with values clamped to [0, 1]. Floating
        inputs keep their dtype; integer inputs yield the default float dtype.
    """
    if device is None:
        device = tensor.device
    # Match floating inputs' dtype (e.g. float16/float64) so the stats don't
    # upcast/mix precision; integer inputs fall back to the default float dtype,
    # as the legacy torch.Tensor(...) constructor did.
    dtype = tensor.dtype if tensor.is_floating_point() else torch.get_default_dtype()
    # Shape (3, 1, 1): one value per channel, broadcast over H and W.
    std = torch.tensor([0.229, 0.224, 0.225], dtype=dtype, device=device).reshape(-1, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=dtype, device=device).reshape(-1, 1, 1)
    return torch.clamp(tensor * std + mean, 0, 1)
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description='AdaIN Style Transfer by Pytorch')
    # NOTE(review): --content is currently unused; main() processes the
    # hard-coded content_dict instead.
    parser.add_argument('--content', '-c', type=str, default='/root/notebooks/XXXXXXXXXX.jpg',
                        help='Content image path e.g. content.jpg')
    parser.add_argument('--style', '-s', type=str, default='/root/notebooks/KKKKKKKKKKKK.jpg',
                        help='Style image path e.g. image.jpg')
    # NOTE(review): --output_name is currently unused; each result is named
    # after its content_dict key.
    parser.add_argument('--output_name', '-o', type=str, default=None,
                        help='Output path for generated image, no need to add ext, e.g. out')
    parser.add_argument('--alpha', '-a', type=float, default=1,
                        help='alpha control the fusion degree in Adain')
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--model_state_path', type=str, default='result/model_state/20_epoch.pth',
                        help='path of the trained model state_dict to load')
    return parser.parse_args()


def _select_device(gpu_id):
    """Return ``cuda:<gpu_id>`` when CUDA is available and ``gpu_id >= 0``, else the CPU device."""
    if torch.cuda.is_available() and gpu_id >= 0:
        device = torch.device(f'cuda:{gpu_id}')
        # Report the GPU actually selected, not always index 0.
        print(f'# CUDA available: {torch.cuda.get_device_name(device)}')
        return device
    return torch.device('cpu')


def _load_rgb(path):
    """Open the image at ``path`` and return a fully decoded RGB copy.

    The conversion makes grayscale/RGBA/palette files compatible with the
    3-channel ``normalize`` transform, and the ``with`` block closes the file
    handle (``convert`` returns a new image, even for RGB input).
    """
    with Image.open(path) as img:
        return img.convert('RGB')


def _make_demo(content, output, style):
    """Compose a side-by-side demo image.

    Left half: ``content``. Right half: ``output`` resized to the content
    size, with ``style`` shrunk to a quarter of the content size and pasted
    in the bottom-left corner of the right half.
    """
    demo = Image.new('RGB', (content.width * 2, content.height))
    output = output.resize(content.size)
    # max(1, ...) avoids a zero-sized thumbnail for very small content images.
    thumb = style.resize(tuple(max(1, side // 4) for side in content.size))
    demo.paste(content, (0, 0))
    demo.paste(output, (content.width, 0))
    demo.paste(thumb, (content.width, content.height - thumb.height))
    return demo


def main():
    """Stylize each image in ``content_dict`` with ``--style`` and save demo composites.

    For every ``name -> path`` entry, writes the stylized image to
    ``<name>.jpg`` in the working directory and a side-by-side composite to
    ``<output_folder>/<name>_style_transfer_demo.jpg``.
    """
    args = _parse_args()
    # set device on GPU if available, else CPU
    device = _select_device(args.gpu)

    # set model
    model = Model()
    if args.model_state_path is not None:
        # map_location='cpu' keeps loaded storages on the CPU (same as the
        # identity lambda); the model is moved to `device` afterwards.
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        model.load_state_dict(torch.load(args.model_state_path, map_location='cpu'))
    model = model.to(device)
    # Inference mode for any dropout/batch-norm layers (no-op if there are none).
    model.eval()

    # Keys become output file stems; values are the content image paths.
    content_dict = {'filename': '/path/to/your/content/image.jpg',
                    'filename2': '/path/to/your/content/image2.jpg',
                    }

    # The style image is identical for every content image: load/encode it once.
    style = _load_rgb(args.style)
    s_tensor = trans(style).unsqueeze(0).to(device)

    for name, path in content_dict.items():
        content = _load_rgb(path)
        c_tensor = trans(content).unsqueeze(0).to(device)
        with torch.no_grad():
            out = model.generate(c_tensor, s_tensor, args.alpha)
        out = denorm(out, device)

        # Use a local name instead of overwriting args.output_name.
        output_name = f'{name}'
        # The stylized image goes to the working directory (not output_folder)
        # and is re-read as a PIL image for the composite.
        save_image(out, f'{output_name}.jpg', nrow=1)
        stylized = _load_rgb(f'{output_name}.jpg')

        demo = _make_demo(content, stylized, style)
        demo.save(f'{output_folder}/{output_name}_style_transfer_demo.jpg', quality=95)
        print(f'result saved into files starting with {output_name}')


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment