Skip to content

Instantly share code, notes, and snippets.

@CasiaFan
Created November 22, 2018 08:03
Show Gist options
  • Save CasiaFan/6cf3e0f4c879b1b58c780d959577ef46 to your computer and use it in GitHub Desktop.
Save CasiaFan/6cf3e0f4c879b1b58c780d959577ef46 to your computer and use it in GitHub Desktop.
An example of using TensorFlow Hub for image generation with BigGAN.
import tensorflow as tf
import tensorflow_hub as hub
import cv2
import numpy as np
from scipy.stats import truncnorm
# TF-Hub handle of the BigGAN generator to load. Exactly one MODULE_PATH line
# should be active; the alternatives are the other published resolutions.
# MODULE_PATH = 'https://tfhub.dev/deepmind/biggan-128/2' # 128x128 BigGAN
MODULE_PATH = 'https://tfhub.dev/deepmind/biggan-256/2' # 256x256 BigGAN
# MODULE_PATH = 'https://tfhub.dev/deepmind/biggan-512/2' # 512x512 BigGAN

# Start from an empty default graph, then load the hub module into it.
tf.reset_default_graph()
module = hub.Module(MODULE_PATH)

# Create one placeholder per input the module declares, keyed (and named)
# by the module's own input names.
inputs = {
    name: tf.placeholder(dtype=info.dtype,
                         shape=info.get_shape().as_list(),
                         name=name)
    for name, info in module.get_input_info_dict().items()
}
# Output tensor of the module; evaluated later with the placeholders fed.
output = module(inputs)
print("Inputs: \n",
      '\n'.join('{}:{}'.format(name, placeholder)
                for name, placeholder in inputs.items()))

# Named handles on the three placeholders that are fed at sampling time.
input_z = inputs['z']
input_y = inputs['y']
input_trunc = inputs['truncation']

# Second dimension of the 'z' and 'y' placeholders; used below as the noise
# vector length and the width of the one-hot label encoding respectively.
dim_z = input_z.shape.as_list()[1]
vocab_size = input_y.shape.as_list()[1]
# generate random noise for image generation
def truncated_z_sample(batch_size, truncation=1., seed=None):
    """Draw a batch of truncated-normal noise vectors.

    Values are sampled from a standard normal truncated to [-2, 2] and then
    multiplied by ``truncation``.

    Args:
        batch_size: Number of noise vectors (rows) to draw.
        truncation: Scale factor applied to every sampled value.
        seed: Optional seed; when given, sampling uses a dedicated
            ``np.random.RandomState(seed)`` and is reproducible. When
            ``None``, no explicit random state is passed to SciPy.

    Returns:
        Array of shape ``(batch_size, dim_z)``, where ``dim_z`` is the
        module-level noise dimension.
    """
    rng = np.random.RandomState(seed) if seed is not None else None
    latent = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=rng)
    return truncation * latent
def one_hot(label, vocab_size=vocab_size):
    """Return ``label`` as a 2-D array, one-hot encoding it if necessary.

    Args:
        label: A scalar class index, a 1-D sequence of class indices, or a
            2-D array. A 2-D array is returned as-is (presumably it is
            already one-hot; its width is not checked against
            ``vocab_size``).
        vocab_size: Number of columns of the encoded output.

    Returns:
        For scalar or 1-D input, a float32 array of shape
        ``(num_labels, vocab_size)`` with a single 1 in each row.
        For 2-D input, ``np.asarray(label)`` unchanged.

    Raises:
        ValueError: If ``label`` has more than two dimensions.
        IndexError: Raised by NumPy when an index is ``>= vocab_size`` or the
            indices are not of an integer type.
    """
    label = np.asarray(label)
    # Raise explicitly rather than assert, so the check is not stripped when
    # Python runs with -O.
    if label.ndim > 2:
        raise ValueError(
            'label must be a scalar, 1-D or 2-D array, got {} dimensions'
            .format(label.ndim))
    if label.ndim == 2:
        return label

    # Promote a scalar index to a length-1 vector so both cases share a path.
    indices = np.atleast_1d(label)
    num = indices.shape[0]
    encoded = np.zeros((num, vocab_size), dtype=np.float32)
    # Row i gets a 1 in column indices[i].
    encoded[np.arange(num), indices] = 1
    return encoded
def sample(sess, noise, label, truncation=1., batch_size=8, vocab_size=vocab_size):
    """Run the generator on ``noise``/``label`` and return uint8 images.

    Args:
        sess: Session used to evaluate the module-level ``output`` tensor.
        noise: Array-like of noise vectors; its first dimension is the
            number of images to generate.
        label: A single class index (applied to every noise vector) or one
            label per noise vector.
        truncation: Value fed to the module's ``truncation`` input.
        batch_size: Maximum number of samples evaluated per ``sess.run``.
        vocab_size: Width of the one-hot label encoding.

    Returns:
        uint8 array holding all generated images, concatenated along the
        first axis in the same order as ``noise``.

    Raises:
        ValueError: If the number of labels differs from the number of
            noise vectors.
    """
    noise = np.asarray(noise)
    label = np.asarray(label)
    num = noise.shape[0]

    if label.ndim == 0:
        # A lone class index is repeated once per noise vector.
        label = np.asarray([label] * num)
    if label.shape[0] != num:
        raise ValueError('Got # noise samples ({}) != # label samples ({})'
                         .format(noise.shape[0], label.shape[0]))
    label = one_hot(label, vocab_size)

    # Evaluate in chunks of at most batch_size; slicing past the end of the
    # arrays simply yields a shorter final chunk.
    batches = [
        sess.run(output, feed_dict={
            input_z: noise[start:start + batch_size],
            input_y: label[start:start + batch_size],
            input_trunc: truncation,
        })
        for start in range(0, num, batch_size)
    ]
    images = np.concatenate(batches, axis=0)
    assert images.shape[0] == num

    # Rescale to pixel values (presumably the generator emits values in
    # [-1, 1] -- TODO confirm) and clip into the valid uint8 range.
    pixels = (images + 1) / 2.0 * 256
    return np.uint8(np.clip(pixels, 0, 255))
# ---- Generation settings -------------------------------------------------
num_samples = 5 # control number of generated images
truncation = 0.6 # controls fidelity (scales the sampled noise)
noise_seed = 0 # fixed seed so the same noise is drawn on every run
category = "283) Persian cat" # which category to generate

# Noise vectors, and the integer class id parsed from the text before ')'.
z = truncated_z_sample(num_samples, truncation, noise_seed)
y = int(category.split(')')[0])

# ---- Run the generator ---------------------------------------------------
initializer = tf.global_variables_initializer()
sess = tf.Session()
try:
    sess.run(initializer)
    ims = sample(sess, z, y, truncation=truncation)
finally:
    # Release the session's resources even if generation fails.
    sess.close()

# ---- Display -------------------------------------------------------------
# Join the individual images along axis 1 of each image (side by side,
# assuming each image is height x width x channels -- TODO confirm).
grid = np.concatenate(ims, axis=1)
grid = np.asarray(grid, np.uint8)
# Reorder channels from RGB to BGR before handing the image to OpenCV.
grid = cv2.cvtColor(grid, cv2.COLOR_RGB2BGR)
try:
    cv2.imshow("res", grid)
    cv2.waitKey(0)  # block until a key is pressed
finally:
    # Close the preview window instead of leaving it to process teardown.
    cv2.destroyAllWindows()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment