-
-
Save adrianlzt/68cdfd78149cbfcc0896a73244e4ecff to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools | |
import os | |
from matplotlib import gridspec | |
import matplotlib.pylab as plt | |
import numpy as np | |
import tensorflow as tf | |
import tensorflow_hub as hub | |
from tensorflow.python.keras.backend import set_session | |
# Configure a TF1-compatible session: allocate GPU memory on demand rather
# than grabbing it all up front, and log which device each op is placed on.
session_config = tf.compat.v1.ConfigProto()
session_config.gpu_options.allow_growth = True  # grow GPU memory as needed
session_config.log_device_placement = True  # log device placement per op
sess = tf.compat.v1.Session(config=session_config)
set_session(sess)

# Report the runtime environment so results are reproducible/debuggable.
print("TF Version: ", tf.__version__)
print("TF-Hub version: ", hub.__version__)
print("Eager mode enabled: ", tf.executing_eagerly())
print("GPU available: ", tf.test.is_gpu_available())
# @title Define image loading and visualization functions { display-mode: "form" } | |
def crop_center(image):
  """Returns a centered square crop of a batched image tensor.

  Args:
    image: tensor indexed as (batch, height, width, ...); only the spatial
      dimensions at axes 1 and 2 are inspected.

  Returns:
    The image cropped to a (side x side) square about its center, where
    side is the smaller of the two spatial dimensions.
  """
  height, width = image.shape[1], image.shape[2]
  side = min(height, width)
  # Offsets center the crop along whichever axis is longer.
  top = max(height - width, 0) // 2
  left = max(width - height, 0) // 2
  return tf.image.crop_to_bounding_box(image, top, left, side, side)
@functools.lru_cache(maxsize=None)
def load_image(image_url, image_size=(256, 256), preserve_aspect_ratio=True):
  """Downloads an image and preprocesses it into a batched float tensor.

  The file is cached locally by tf.keras.utils.get_file, and the whole call
  is memoized by lru_cache, so repeated loads of the same URL are free.

  Args:
    image_url: URL of the image to fetch.
    image_size: (height, width) target size for tf.image.resize.
    preserve_aspect_ratio: forwarded to tf.image.resize.

  Returns:
    A float32 tensor of shape (1, H, W, 3) with values in [0, 1],
    center-cropped to a square and resized to image_size.
  """
  # Cache the image file locally; cap the derived filename at 128 chars.
  image_path = tf.keras.utils.get_file(os.path.basename(image_url)[-128:], image_url)
  # Load as float32, add a leading batch dimension.
  img = plt.imread(image_path).astype(np.float32)[np.newaxis, ...]
  # Normalize uint8-style pixel ranges to [0, 1].
  if img.max() > 1.0:
    img = img / 255.
  # Grayscale images arrive as (1, H, W); replicate to 3 channels.
  if len(img.shape) == 3:
    img = tf.stack([img, img, img], axis=-1)
  img = crop_center(img)
  # BUG FIX: the resize previously hard-coded preserve_aspect_ratio=True,
  # silently ignoring the caller's argument. Forward it instead; the default
  # keeps the original behavior.
  img = tf.image.resize(img, image_size, preserve_aspect_ratio=preserve_aspect_ratio)
  return img
def show_n(images, titles=('',)):
  """Displays a row of batched images side by side with optional titles.

  Args:
    images: sequence of tensors shaped (1, H, W, C); the batch entry at
      index 0 of each is shown.
    titles: per-image titles; missing entries default to ''.
  """
  count = len(images)
  # Width ratios come from each image's size along axis 1.
  sizes = [img.shape[1] for img in images]
  scale = (sizes[0] * 6) // 320
  plt.figure(figsize=(scale * count, scale))
  grid = gridspec.GridSpec(1, count, width_ratios=sizes)
  for idx in range(count):
    plt.subplot(grid[idx])
    plt.imshow(images[idx][0], aspect='equal')
    plt.axis('off')
    title = titles[idx] if idx < len(titles) else ''
    plt.title(title)
  plt.show()
# @title Load example images { display-mode: "form" }
# NOTE: the original kept a dead assignment here (immediately overwritten);
# preserved as a commented alternative content image.
# content_image_url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/f/fd/Golden_Gate_Bridge_from_Battery_Spencer.jpg/640px-Golden_Gate_Bridge_from_Battery_Spencer.jpg'  # @param {type:"string"}
content_image_url = 'https://i.stack.imgur.com/ZME5U.jpg'  # @param {type:"string"}
style_image_url = 'https://upload.wikimedia.org/wikipedia/commons/0/0a/The_Great_Wave_off_Kanagawa.jpg'  # @param {type:"string"}
output_image_size = 384  # @param {type:"integer"}

# The content image size can be arbitrary.
content_img_size = (output_image_size, output_image_size)
# The style prediction model was trained with image size 256 and it's the
# recommended image size for the style image (though, other sizes work as
# well but will lead to different results).
style_img_size = (256, 256)  # Recommended to keep it at 256.

content_image = load_image(content_image_url, content_img_size)
style_image = load_image(style_image_url, style_img_size)
# Lightly smooth the style image; average pooling softens high-frequency
# detail before style extraction.
style_image = tf.nn.avg_pool(style_image, ksize=[3,3], strides=[1,1], padding='SAME')
#show_n([content_image, style_image], ['Content image', 'Style image'])

# Load TF-Hub module.
hub_handle = 'https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2'
hub_module = hub.load(hub_handle)

# Stylize content image with given style image.
# BUG FIX: the module was previously called twice with identical inputs and
# the first result discarded; a single call is sufficient.
# This is pretty fast within a few milliseconds on a GPU.
outputs = hub_module(tf.constant(content_image), tf.constant(style_image))
stylized_image = outputs[0]

# Visualize input images and the generated stylized image.
show_n([content_image, style_image, stylized_image],
       titles=['Original content image', 'Style image', 'Stylized image'])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment