Skip to content

Instantly share code, notes, and snippets.

@CharStiles
Last active March 27, 2018 05:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save CharStiles/e40cbade2534d5fb41c5ebd594d03cb7 to your computer and use it in GitHub Desktop.
Save CharStiles/e40cbade2534d5fb41c5ebd594d03cb7 to your computer and use it in GitHub Desktop.
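'''
ZeroMQ REP server (TCP port 8080) that receives raw 480x320 RGB frames, runs
each one through a pix2pix generator, and replies with the translated frame
at the same 480x320 RGB size.

Based on test.py from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
Thanks to Aman Tiwari for the help.
To run:
python ServerToProcessFaceSketch.py --dataroot ./darta --name face_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction AtoB --dataset_mode aligned --norm batch
'''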
import os
import io
import time
import zmq
import random
import sys
import base64
import util.util as util
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
import numpy as np
from PIL import Image
from scipy.misc import imresize
from torch.autograd import Variable
from torchvision import transforms
# Input pipeline for the generator: PIL image -> float tensor in [0, 1]
# (ToTensor), then each RGB channel is mapped to [-1, 1] via (x - 0.5) / 0.5.
_HALF = (0.5, 0.5, 0.5)
prepare = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=_HALF, std=_HALF),
])
# ZeroMQ reply (REP) socket the client connects to on all interfaces.
port = "8080"
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind("tcp://*:%s" % port)

# Test-time options parsed from the command line, then pinned to the only
# values the test code supports.
opt = TestOptions().parse()
opt.nThreads = 1  # test code only supports nThreads = 1
opt.batchSize = 1  # test code only supports batchSize = 1
opt.serial_batches = True  # no shuffle
opt.no_flip = True  # no flip
# NOTE(review): data_loader/dataset are never used by the request loop below;
# kept in case constructing them has side effects something relies on.
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
model = create_model(opt)

# NOTE(review): ff is never used afterwards, so this read only fails fast when
# the sample image is missing. `with` closes the file even if read() raises.
with open("./darta/test/img.jpg", 'rb') as f:
    ff = f.read()
def forward(model, inp):
    """Translate one PIL image with the generator ``model.netG``.

    Args:
        model: object exposing a ``netG`` generator callable.
        inp: PIL image; the module-level ``prepare`` transform turns it into a
            normalized tensor.

    Returns:
        The result of ``util.tensor2im`` on the generator output. The caller
        passes it to ``Image.fromarray``, so it is presumably a uint8 HxWx3
        array (confirm against util.tensor2im).
    """
    # volatile=True is the pre-0.4 PyTorch way to run without recording
    # gradients; unsqueeze(0) adds the batch dimension of size 1.
    batch = Variable(prepare(inp).cuda(), volatile=True).unsqueeze(0)
    return util.tensor2im(model.netG(batch).data)
# (width, height) of the raw RGB frames exchanged with the client.
FRAME_SIZE = (480, 320)
# Square size each frame is resized to before it is fed to the generator.
MODEL_SIZE = (256, 256)

while True:
    # REP socket: every recv() must be answered by exactly one send() before
    # the next recv(), otherwise ZeroMQ raises a state error.
    msg = socket.recv()
    try:
        # Needs at least 480*320*3 bytes of packed RGB pixels; PIL raises
        # ValueError when there is not enough image data.
        img = Image.frombytes('RGB', FRAME_SIZE, msg).resize(MODEL_SIZE, Image.BICUBIC)
    except ValueError:
        # Malformed or short frame: reply with an empty message instead of
        # crashing the server, which keeps the request/reply lockstep intact
        # and lets the client detect the failure.
        socket.send(b'')
        continue
    result = forward(model, img)
    im = Image.fromarray(result).resize(FRAME_SIZE)
    # Snapshot of the latest output, overwritten on every request.
    im.save('is_art2.png')
    # NOTE(review): the purpose of these fixed delays (1.5 s per frame in
    # total) is not evident from this file; confirm before removing.
    time.sleep(0.5)
    # Reply with the output frame's raw RGB pixel buffer.
    socket.send(np.array(im))
    time.sleep(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment