Created August 3, 2017 09:03
Save AlexanderFabisch/da3a5b5cda73bbfecfb6bf1390248c5b to your computer and use it in GitHub Desktop.
pix2pix image generation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import argparse | |
if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description='Generate images with a trained pix2pix generator')
    parser.add_argument('patch_size', type=int, nargs=2, action="store", help="Patch size for D")
    # ``choices`` makes argparse reject any other backend up front; without it
    # ``image_dim_ordering`` below would be left unbound (NameError).
    parser.add_argument('--backend', type=str, default="theano",
                        choices=["theano", "tensorflow"], help="theano or tensorflow")
    parser.add_argument('--generator', type=str, default="upsampling", help="upsampling or deconv")
    parser.add_argument('--dset', type=str, default="facades", help="facades")
    parser.add_argument('--batch_size', default=4, type=int, help='Batch size')
    parser.add_argument('--do_plot', action="store_true", help="Debugging plot")
    parser.add_argument('--bn_mode', default=2, type=int, help="Batch norm mode")
    parser.add_argument('--img_dim', default=64, type=int, help="Image width == height")
    parser.add_argument('--use_mbd', action="store_true", help="Whether to use minibatch discrimination")
    parser.add_argument('--use_label_smoothing', action="store_true",
                        help="Whether to smooth the positive labels when training D")
    parser.add_argument('--label_flipping', default=0, type=float,
                        help="Probability (0 to 1.) to flip the labels when training D")
    args = parser.parse_args()

    # Select the Keras backend through the environment variable. This must
    # happen before keras is imported below for the choice to take effect.
    os.environ["KERAS_BACKEND"] = args.backend

    # Import the backend
    import keras.backend as K

    # Set the dim ordering manually; otherwise it is not changed to match
    # the selected backend.
    image_dim_ordering = {"theano": "th", "tensorflow": "tf"}[args.backend]
    K.set_image_dim_ordering(image_dim_ordering)

    # NOTE(review): meant to import the project-local test.py (the generation
    # script); the name also belongs to the stdlib ``test`` package, so this
    # relies on the script directory coming first on sys.path.
    import test

    # Set default params. Some keys (e.g. do_plot, use_label_smoothing,
    # label_flipping, model_name) look like leftovers from the training
    # script -- verify whether test.test actually uses them.
    d_params = {"dset": args.dset,
                "generator": args.generator,
                "batch_size": args.batch_size,
                "model_name": "CNN",
                "do_plot": args.do_plot,
                "image_dim_ordering": image_dim_ordering,
                "bn_mode": args.bn_mode,
                "img_dim": args.img_dim,
                "use_label_smoothing": args.use_label_smoothing,
                "label_flipping": args.label_flipping,
                "patch_size": args.patch_size,
                "use_mbd": args.use_mbd,
                }

    # Launch generation
    test.test(**d_params)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
import time | |
import numpy as np | |
import models | |
from keras.optimizers import Adam, SGD | |
import keras.backend as K | |
# Utils | |
sys.path.append("../utils") | |
import general_utils | |
import data_utils | |
def test(**kwargs):
    """
    Generate images with a pretrained pix2pix U-Net generator.

    The whole train split is loaded in memory. Batches drawn from
    ``data_utils.gen_batch`` are fed through the generator: the sketch batch
    (second element of each yielded pair) is the input. Each predicted batch
    is written with ``np.save`` to ``output_pattern % batch_index``, relative
    to the current working directory.

    The run ends when the batch generator is exhausted, when ``nb_batches``
    batches have been saved (if given), or on Ctrl-C.

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters

        Required keys:
            batch_size, generator, image_dim_ordering, patch_size, bn_mode,
            dset, use_mbd.

        Optional keys:
            weights_path (str): generator weights file to load
                (default "../../models/CNN.backup/gen_weights_epoch395.h5").
            output_pattern (str): %-format pattern filled with the batch
                index (default "batch_%02d.npy").
            nb_batches (int or None): maximum number of batches to save
                (default None, i.e. no limit).

        Other keys (e.g. img_dim, do_plot) are accepted and ignored; the
        image shape is taken from the loaded data.
    """
    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    generator = kwargs["generator"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    patch_size = kwargs["patch_size"]
    bn_mode = kwargs["bn_mode"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]
    weights_path = kwargs.get(
        "weights_path", "../../models/CNN.backup/gen_weights_epoch395.h5")
    output_pattern = kwargs.get("output_pattern", "batch_%02d.npy")
    nb_batches = kwargs.get("nb_batches")

    # Load and rescale data
    X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(
        dset, image_dim_ordering)
    # The model's input shape comes from the data itself
    img_dim = X_full_train.shape[-3:]

    # Get the number of non overlapping patch
    nb_patch, _ = data_utils.get_nb_patch(img_dim, patch_size, image_dim_ordering)

    # Load generator model
    generator_model = models.load("generator_unet_%s" % generator,
                                  img_dim,
                                  nb_patch,
                                  bn_mode,
                                  use_mbd,
                                  batch_size)
    opt_generator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    generator_model.compile(loss='mae', optimizer=opt_generator)
    generator_model.load_weights(weights_path)

    print("Start generating")
    nb_saved = 0
    try:
        # Use the batch produced by *this* loop; creating a new generator
        # every iteration would discard the loop's own batch.
        for X_full_batch, X_sketch_batch in data_utils.gen_batch(
                X_full_train, X_sketch_train, batch_size):
            if nb_batches is not None and nb_saved >= nb_batches:
                break
            y = generator_model.predict(X_sketch_batch, verbose=1)
            np.save(output_pattern % nb_saved, y)
            nb_saved += 1
    except KeyboardInterrupt:
        # Ctrl-C ends the run quietly instead of with a traceback
        print("Interrupted after %d batches" % nb_saved)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.