Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
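"""pix2pix inference server over ZeroMQ.

Based off of an existing pix2pix test script (the original link was lost in
extraction). Thanks to Aman Tiwari for the help.

To run (substitute this script's file name):
    python <script>.py --dataroot ./darta --name face_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction AtoB --dataset_mode aligned --norm batch
"""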
import os
import io
import time
import zmq
import random
import sys
import base64
import util.util as util
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
import numpy as np
from PIL import Image
from scipy.misc import imresize
from torch.autograd import Variable
from torchvision import transforms
# Preprocessing applied to each incoming PIL RGB image before it reaches the
# generator: ToTensor converts it to a float tensor in [0, 1], then Normalize
# maps each channel to [-1, 1] via (x - 0.5) / 0.5. ToTensor is required
# because Normalize (and the later .cuda() call) only accept tensors, while
# the server loop hands `forward` a PIL image.
prepare = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
port = "8080"

# ZeroMQ reply socket listening on all interfaces; the loop below receives a
# request on it and answers each one.
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind("tcp://*:" + port)

# Parse the pix2pix test options from the command line, then force the
# settings the single-image test path requires.
opt = TestOptions().parse()
opt.nThreads = opt.batchSize = 1  # test code only supports one thread and batch size 1
opt.serial_batches = opt.no_flip = True  # no shuffling, no flipping

# NOTE(review): `dataset` is not used by the visible code; confirm whether the
# data loader is needed for anything beyond validating --dataroot.
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
model = create_model(opt)
# NOTE(review): the right-hand side of the original `ff =` assignment was lost
# in extraction; reading the file's bytes is a reconstruction. `ff` is not used
# by the visible code, so confirm whether this sample image is needed at all.
# The `with` block guarantees the file handle is closed (the original leaked it).
with open("./darta/test/img.jpg", 'rb') as f:
    ff = f.read()
def forward(model, inp):
    """Run the pix2pix generator on a single image.

    Args:
        model: pix2pix model from ``create_model``; only ``model.netG`` is used.
        inp: PIL RGB image, already resized to the network input size
            (the server loop resizes to 256x256).

    Returns:
        The generator output converted by ``util.tensor2im`` (the caller feeds
        it to ``Image.fromarray``, so it is expected to be a numpy image array).
    """
    tensor = prepare(inp)
    # volatile=True: inference only, so no autograd graph is recorded
    # (pre-0.4 PyTorch API, matching the Variable usage in this file).
    inp_var = Variable(tensor.cuda(), volatile=True)
    inp_var = inp_var.unsqueeze(0)  # add the batch dimension; batch size is 1
    pred = model.netG(inp_var)
    # Pass the underlying tensor (.data), not the Variable. Presumably this is
    # what tensor2im expects in the pix2pix repo; confirm against util.util.
    return util.tensor2im(pred.data)
# Serve forever: receive a raw 480x320 RGB frame, run it through the
# generator, save the result, and reply. A REP socket must send a reply
# after every recv before it can receive again, so every path below replies.
while True:
    msg = socket.recv()
    try:
        img = Image.frombytes('RGB', (480, 320), msg).resize((256, 256), Image.BICUBIC)
    except ValueError:
        # Payload does not hold a full 480x320 RGB frame. Answer with an empty
        # message so the socket stays usable instead of crashing the server.
        socket.send(b'')
        continue
    result = forward(model, img)
    im = Image.fromarray(result).resize((480, 320))
    im.save('is_art2.png')
    # NOTE(review): the original reply was lost in extraction; sending raw RGB
    # bytes mirrors the request format. Confirm against the client.
    socket.send(im.tobytes())
    time.sleep(0.5)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment