Created
January 8, 2020 17:22
-
-
Save jeremyadamsfisher/c7b3ca557487e5dce31c00e08a8917f0 to your computer and use it in GitHub Desktop.
Simple Python function to run CycleGAN on a single image
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Usage example: translate a horse photo into a zebra photo.
# The opened image is bound to `img_horse`, so that is the name that must be
# handed to `cyclegan` (the original snippet passed an undefined `img`).
img_horse = Image.open("./horse.png")
img_zebra = cyclegan(img_horse, "./horse2zebra.pth", "horse2zebra")
img_zebra.save("./zebra.png")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
torch | |
torchvision | |
Pillow==6.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""glue code to get cycleGAN working in the | |
context of a python program WITHOUT invoking | |
a subprocess""" | |
import typing as t | |
import os, sys | |
sys.path.append(os.path.join(os.getcwd(), "pytorch-CycleGAN-and-pix2pix")) | |
import torch | |
import io | |
import contextlib | |
from data.base_dataset import get_transform | |
from models.cycle_gan_model import CycleGANModel | |
from util.util import tensor2im | |
from PIL import Image | |
from argparse import Namespace | |
from pathlib import Path | |
from copy import deepcopy | |
# Default option set used to construct the CycleGAN model and the input
# transform. `cyclegan()` deep-copies this object, overlays the caller's
# keyword arguments on the copy, and then fills in `name`, so the values
# below are never mutated in place.
# NOTE(review): the field names presumably mirror the options of the
# pytorch-CycleGAN-and-pix2pix command line interface -- confirm against
# that project before adding or renaming fields.
OPT = Namespace(
    aspect_ratio=1.0,
    batch_size=1,
    checkpoints_dir="./checkpoints",
    crop_size=256,
    dataroot=".",
    dataset_mode="unaligned",
    direction="AtoB",
    display_id=-1,
    display_winsize=256,
    epoch="latest",
    eval=False,
    gpu_ids=[],
    init_gain=0.02,
    init_type="normal",
    input_nc=3,
    isTrain=False,
    load_iter=0,
    load_size=256,
    max_dataset_size=float("inf"),
    model="cycle_gan",
    n_layers_D=3,
    name=None,  # set per call by cyclegan() from its `model_name` argument
    ndf=64,
    netD="basic",
    netG="resnet_9blocks",
    ngf=64,
    no_dropout=True,
    no_flip=True,
    norm="instance",
    ntest=float("inf"),
    num_test=100,
    num_threads=0,
    output_nc=3,
    phase="test",
    preprocess="no_preprocessing",
    results_dir="./results/",
    serial_batches=True,
    suffix="",
    verbose=False,
)
class SingleImageDataset(torch.utils.data.Dataset):
    """Dataset with precisely one image.

    Arguments:
        img: the raw image to wrap
        preprocess: callable applied to ``img`` once, at construction
            time; its return value is the single item the dataset serves
    """

    def __init__(self, img, preprocess):
        self.img = preprocess(img)

    def __getitem__(self, i):
        # Only one slot exists (index 0, or -1 counting from the end).
        # Raising IndexError for everything else lets plain iteration over
        # the dataset -- which probes indices 0, 1, 2, ... through
        # __getitem__ -- stop after the first item instead of looping
        # forever.
        if i not in (0, -1):
            raise IndexError(
                f"SingleImageDataset holds exactly one item; got index {i}"
            )
        return self.img

    def __len__(self):
        return 1
def load_model(opt, fp):
    """Build the ``netG_A`` generator of a CycleGAN model and load its weights.

    Arguments:
        opt: option namespace handed to ``CycleGANModel``
        fp: location of the generator weights (.pth)

    Returns:
        the ``netG_A`` network with the state dict from ``fp`` loaded
    """
    model = CycleGANModel(opt).netG_A
    # map_location="cpu" lets checkpoints that were saved from GPU tensors be
    # read on machines without CUDA; load_state_dict then copies the values
    # into the parameters wherever the network itself lives.
    # NOTE(security): torch.load unpickles the file -- only load weights
    # obtained from a trusted source.
    state_dict = torch.load(fp, map_location="cpu")
    model.load_state_dict(state_dict)
    # Honour the `eval` option (False by default, so existing behaviour is
    # unchanged unless the caller asks for eval mode explicitly).
    if getattr(opt, "eval", False):
        model.eval()
    return model
def cyclegan(img: Image.Image,
             model_fp: t.Union[Path, str],
             model_name: str,
             **kwargs) -> Image.Image:
    """Run cyclegan on a single image.

    Arguments:
        img: Pillow image to be run through cyclegan
        model_fp: location of the model weights (.pth)
        model_name: name of the model (specified with --name in
            cyclegan command line interface)
        **kwargs: overrides applied on top of a copy of the default
            ``OPT`` namespace before it is passed to cycleGAN

    Returns:
        the generator output converted back into a Pillow image
    """
    # Work on a private copy so per-call overrides never leak into OPT.
    opt = deepcopy(OPT)
    vars(opt).update(kwargs)
    opt.name = model_name

    # Unless verbose output was requested, swallow whatever is printed to
    # stdout while the model is being built and its weights are loaded.
    if opt.verbose:
        stdout_guard = contextlib.nullcontext()
    else:
        stdout_guard = contextlib.redirect_stdout(io.StringIO())
    with stdout_guard:
        model = load_model(opt, model_fp)

    img = img.convert("RGB")
    # A one-item DataLoader is used to collate the preprocessed image into a
    # batch of size 1 (presumably the layout the generator expects -- confirm
    # against get_transform's output).
    data_loader = torch.utils.data.DataLoader(
        SingleImageDataset(img, get_transform(opt)), batch_size=1
    )
    data = next(iter(data_loader))

    # Inference only: no gradients are needed.
    with torch.no_grad():
        pred = model(data)

    pred_arr = tensor2im(pred)
    return Image.fromarray(pred_arr)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment