pix2pix image generation
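Two files follow: an entry-point script that parses the command-line arguments, configures the Keras backend, and calls test.test, and the module test.py, which loads a trained U-Net generator, runs it on batches of the training data, and saves the generated images to .npy files.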
import os
import argparse


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Train model')
    parser.add_argument('patch_size', type=int, nargs=2, action="store", help="Patch size for D")
    parser.add_argument('--backend', type=str, default="theano", help="theano or tensorflow")
    parser.add_argument('--generator', type=str, default="upsampling", help="upsampling or deconv")
    parser.add_argument('--dset', type=str, default="facades", help="facades")
    parser.add_argument('--batch_size', default=4, type=int, help='Batch size')
    parser.add_argument('--do_plot', action="store_true", help="Debugging plot")
    parser.add_argument('--bn_mode', default=2, type=int, help="Batch norm mode")
    parser.add_argument('--img_dim', default=64, type=int, help="Image width == height")
    parser.add_argument('--use_mbd', action="store_true", help="Whether to use minibatch discrimination")
    parser.add_argument('--use_label_smoothing', action="store_true", help="Whether to smooth the positive labels when training D")
    parser.add_argument('--label_flipping', default=0, type=float, help="Probability (0 to 1.) to flip the labels when training D")
    args = parser.parse_args()

    # Set the backend by modifying the env variable
    if args.backend == "theano":
        os.environ["KERAS_BACKEND"] = "theano"
    elif args.backend == "tensorflow":
        os.environ["KERAS_BACKEND"] = "tensorflow"

    # Import the backend
    import keras.backend as K

    # manually set dim ordering otherwise it is not changed
    if args.backend == "theano":
        image_dim_ordering = "th"
        K.set_image_dim_ordering(image_dim_ordering)
    elif args.backend == "tensorflow":
        image_dim_ordering = "tf"
        K.set_image_dim_ordering(image_dim_ordering)

    import test

    # Set default params
    d_params = {"dset": args.dset,
                "generator": args.generator,
                "batch_size": args.batch_size,
                "model_name": "CNN",
                "do_plot": args.do_plot,
                "image_dim_ordering": image_dim_ordering,
                "bn_mode": args.bn_mode,
                "img_dim": args.img_dim,
                "use_label_smoothing": args.use_label_smoothing,
                "label_flipping": args.label_flipping,
                "patch_size": args.patch_size,
                "use_mbd": args.use_mbd
                }
    # Launch image generation (test.test loads a trained generator and saves generated batches)
    test.test(**d_params)
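With the argument parser above, an example invocation might be python <script> 64 64 --backend tensorflow --dset facades, where the two positional integers are the discriminator patch size (the gist does not name this file, so the command is illustrative only). The second file of the gist follows; it is the module imported above as import test, so it must be saved as test.py.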
import os
import sys
import time
import numpy as np
import models
from keras.optimizers import Adam, SGD
import keras.backend as K
# Utils
sys.path.append("../utils")
import general_utils
import data_utils


def test(**kwargs):
    """
    Generate images with a trained pix2pix generator

    Load the whole train data in memory for faster operations

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """
    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    generator = kwargs["generator"]
    image_dim_ordering = kwargs["image_dim_ordering"]
    img_dim = kwargs["img_dim"]
    patch_size = kwargs["patch_size"]
    bn_mode = kwargs["bn_mode"]
    dset = kwargs["dset"]
    use_mbd = kwargs["use_mbd"]

    # Load and rescale data
    X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(
        dset, image_dim_ordering)
    img_dim = X_full_train.shape[-3:]

    # Get the number of non overlapping patch
    nb_patch, _ = data_utils.get_nb_patch(img_dim, patch_size, image_dim_ordering)

    try:
        # Load generator model
        generator_model = models.load("generator_unet_%s" % generator,
                                      img_dim,
                                      nb_patch,
                                      bn_mode,
                                      use_mbd,
                                      batch_size)

        opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        generator_model.compile(loss='mae', optimizer=opt_discriminator)
        generator_model.load_weights("../../models/CNN.backup/gen_weights_epoch395.h5")

        print("Start generating")
        i = 0
        # Generate batches until interrupted (Ctrl+C, handled below)
        for X_gen_target, X_gen in data_utils.gen_batch(X_full_train, X_sketch_train, batch_size):
            # Run the generator on the sketch batch and save the generated images
            y = generator_model.predict(X_gen, verbose=1)
            np.save("batch_%02d.npy" % i, y)
            i += 1
    except KeyboardInterrupt:
        pass
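The loop above writes each generated batch to batch_00.npy, batch_01.npy, and so on. Below is a minimal sketch for inspecting these files afterwards, assuming the default Theano ("th") ordering so the saved arrays are (batch_size, channels, height, width) and that the output range needs rescaling for display; both are assumptions, and the helper itself is not part of the gist.

# inspect_batches.py -- hypothetical helper script, not part of the original gist
import glob

import numpy as np
import matplotlib.pyplot as plt

for path in sorted(glob.glob("batch_*.npy")):
    batch = np.load(path)
    # Assumed channels-first ("th") layout; move channels last for plotting
    if batch.ndim == 4 and batch.shape[1] in (1, 3):
        batch = batch.transpose(0, 2, 3, 1)
    # Show the first image of the batch, rescaled to [0, 1] for display
    img = batch[0]
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
    plt.imshow(img.squeeze())
    plt.title(path)
    plt.show()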