Skip to content

Instantly share code, notes, and snippets.

@jeremyadamsfisher
Created January 8, 2020 17:22
Show Gist options
  • Save jeremyadamsfisher/c7b3ca557487e5dce31c00e08a8917f0 to your computer and use it in GitHub Desktop.
Save jeremyadamsfisher/c7b3ca557487e5dce31c00e08a8917f0 to your computer and use it in GitHub Desktop.
Simple python function to run cycleGAN
# Example usage: translate a horse photo into a zebra with the horse2zebra model.
img_horse = Image.open("./horse.png")
img_zebra = cyclegan(img_horse, "./horse2zebra.pth", "horse2zebra")
img_zebra.save("./zebra.png")
torch
torchvision
Pillow==6.2
"""glue code to get cycleGAN working in the
context of a python program WITHOUT invoking
a subprocess"""
import typing as t
import os, sys
sys.path.append(os.path.join(os.getcwd(), "pytorch-CycleGAN-and-pix2pix"))
import torch
import io
import contextlib
from data.base_dataset import get_transform
from models.cycle_gan_model import CycleGANModel
from util.util import tensor2im
from PIL import Image
from argparse import Namespace
from pathlib import Path
from copy import deepcopy
# Baseline cycleGAN options for single-image inference.
# cyclegan() works on a deep copy of this object and then overlays its
# **kwargs and the model name. The shared instance, including the mutable
# ``gpu_ids`` list, is therefore left untouched by those calls.
# Keys are kept in alphabetical order, matching the original definition.
OPT = Namespace(
    **{
        "aspect_ratio": 1.0,
        "batch_size": 1,
        "checkpoints_dir": "./checkpoints",
        "crop_size": 256,
        "dataroot": ".",
        "dataset_mode": "unaligned",
        "direction": "AtoB",
        "display_id": -1,
        "display_winsize": 256,
        "epoch": "latest",
        "eval": False,
        # empty: presumably means CPU-only in the cycleGAN repo — confirm
        "gpu_ids": [],
        "init_gain": 0.02,
        "init_type": "normal",
        "input_nc": 3,
        "isTrain": False,
        "load_iter": 0,
        "load_size": 256,
        "max_dataset_size": float("inf"),
        "model": "cycle_gan",
        "n_layers_D": 3,
        # placeholder; cyclegan() always replaces it with its model_name argument
        "name": None,
        "ndf": 64,
        "netD": "basic",
        "netG": "resnet_9blocks",
        "ngf": 64,
        "no_dropout": True,
        "no_flip": True,
        "norm": "instance",
        "ntest": float("inf"),
        "num_test": 100,
        "num_threads": 0,
        "output_nc": 3,
        "phase": "test",
        # NOTE(review): not one of the repo's named preprocess modes; presumably
        # chosen so get_transform applies no resize/crop — confirm
        "preprocess": "no_preprocessing",
        "results_dir": "./results/",
        "serial_batches": True,
        "suffix": "",
        "verbose": False,
    }
)
class SingleImageDataset(torch.utils.data.Dataset):
    """Dataset holding exactly one image.

    Arguments:
        img: the raw input, passed to ``preprocess``.
        preprocess: callable applied once, at construction time; its result
            is what the single valid index returns.
    """

    def __init__(self, img, preprocess):
        self.img = preprocess(img)

    def __getitem__(self, i):
        # Only index 0 (or -1, sequence-style) exists. torch's Dataset base
        # class defines no __iter__, so iter(ds)/list(ds) falls back to the
        # __getitem__ sequence protocol and stops only on IndexError. Without
        # this check that iteration would never terminate.
        if i not in (0, -1):
            raise IndexError("SingleImageDataset index out of range: {!r}".format(i))
        return self.img

    def __len__(self):
        return 1
def load_model(opt, fp):
    """Build the cycleGAN A->B generator described by ``opt`` and load weights into it.

    Arguments:
        opt: cycleGAN option namespace (see ``OPT``).
        fp: path to a generator state dict (.pth).

    Returns:
        ``CycleGANModel(opt).netG_A`` with the loaded weights, switched to
        eval mode.
    """
    model = CycleGANModel(opt).netG_A
    # When gpu_ids is non-empty, the cycleGAN repo presumably wraps the
    # generator in DataParallel (verify against networks.init_net). The saved
    # weights belong to the bare module, so load into that to avoid a
    # "module." key-prefix mismatch.
    net = model.module if isinstance(model, torch.nn.DataParallel) else model
    # Deserialize onto CPU so checkpoints saved from a GPU also load on
    # CPU-only machines; load_state_dict then copies the tensors onto
    # whatever device the parameters already live on.
    # NOTE(review): torch.load unpickles data; only load trusted weight files.
    # NOTE(review): older pretrained checkpoints may contain InstanceNorm
    # running-stat keys that a strict load rejects — confirm for your weights.
    state_dict = torch.load(fp, map_location="cpu")
    net.load_state_dict(state_dict)
    # Inference only: make dropout / batch-norm layers (if enabled via opt)
    # behave deterministically.
    model.eval()
    return model
def cyclegan(img: Image.Image,
             model_fp: t.Union[Path, str],
             model_name: str,
             **kwargs) -> Image.Image:
    """Run a cycleGAN generator on a single Pillow image.

    Arguments:
        img: Pillow image to translate; it is converted to RGB before
            preprocessing.
        model_fp: location of the generator weights (.pth).
        model_name: name of the model (specified with --name in the cyclegan
            command line interface). Always overrides any ``name`` in kwargs.
        **kwargs: overrides applied to a deep copy of the module-level ``OPT``.

    Returns:
        The generator output as a new Pillow image.
    """
    opt = deepcopy(OPT)
    vars(opt).update(kwargs)
    opt.name = model_name

    # Model construction/loading may write to stdout; discard that output
    # unless the caller asked for verbose mode.
    if opt.verbose:
        model = load_model(opt, model_fp)
    else:
        with contextlib.redirect_stdout(io.StringIO()):
            model = load_model(opt, model_fp)
    # torch.no_grad() only disables autograd; it does not switch Dropout or
    # BatchNorm to inference behaviour. Without eval(), passing
    # no_dropout=False or norm="batch" via kwargs gives train-mode output.
    model.eval()

    dataset = SingleImageDataset(img.convert("RGB"), get_transform(opt))
    # The default collate adds the leading batch dimension the generator expects.
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
    data = next(iter(data_loader))
    with torch.no_grad():
        pred = model(data)
    return Image.fromarray(tensor2im(pred))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment