Last active
February 5, 2018 04:24
-
-
Save unixpickle/856cef441e9f1623449687c5ea043634 to your computer and use it in GitHub Desktop.
Remove watermarks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Strip watermarks from photos.

Expects a directory 'pairs' containing watermarked and
unwatermarked versions of the same photos. Watermarked photos
have a '_w' suffix in their filename, and all names end with '.JPG'.

Expects another directory, 'remove_watermarks', containing an
arbitrary set of photos. Each of these photos is stripped of the
watermarks that were present in the training pairs and is
overwritten in place (resized to 1000x666).
"""
import os | |
import pickle | |
from PIL import Image | |
import numpy as np | |
import tensorflow as tf | |
def train_model(dir_path='pairs', model_path='model.pkl', num_steps=3000, max_lr=0.01,
                use_regularization=False):
    """Train a watermark removal model and pickle its variable values.

    Args:
      dir_path: directory of training pairs (see pair_paths()).
      model_path: file that receives the pickled list of trainable-variable
        values, in tf.trainable_variables() order.
      num_steps: number of full-batch Adam steps to run.
      max_lr: initial learning rate; it decays linearly towards 0.
      use_regularization: if True, add the graph's regularization losses (the
        L2 penalty model() attaches to its variables) to the objective. The
        default False keeps the original, unregularized objective.

    Raises:
      ValueError: if dir_path contains no watermarked/unwatermarked pairs.
    """
    print('Loading data...')
    pairs = pair_paths(dir_path)
    if not pairs:
        # zip(*[]) would otherwise fail with an opaque unpacking error.
        raise ValueError('no training pairs (*_w.JPG) found in %r' % (dir_path,))
    watermarked_paths, clean_paths = zip(*pairs)
    inputs = load_images(watermarked_paths)
    outputs = load_images(clean_paths)
    print('Constructing graph...')
    step_size = tf.placeholder(tf.float32, shape=())
    # Mean absolute (L1) error between the model's output and the clean photos.
    loss = tf.reduce_mean(tf.abs(model(tf.constant(inputs)) - outputs))
    total_loss = loss
    if use_regularization:
        total_loss = loss + tf.losses.get_regularization_loss()
    minimize = tf.train.AdamOptimizer(learning_rate=step_size).minimize(total_loss)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print('Training...')
        for i in range(num_steps):
            # Linear decay; float() keeps the division fractional on Python 2.
            cur_lr = (1 - float(i) / num_steps) * max_lr
            cur_loss, _ = sess.run((loss, minimize), feed_dict={step_size: cur_lr})
            print('step %d: loss=%f' % (i, cur_loss))
        print('Saving model...')
        with open(model_path, 'wb') as out_file:
            # apply_model() restores these values by zipping them with
            # tf.trainable_variables(), so the order here matters.
            pickle.dump(sess.run(tf.trainable_variables()), out_file)
def apply_model(dir_path='remove_watermarks', model_path='model.pkl', out_dir=None):
    """Remove watermarks from a directory of images.

    Every '*.JPG' file in dir_path is loaded (resized by load_images() to its
    default 1000x666), passed through model() using the variable values
    pickled by train_model(), clipped to [0, 1], and saved as a JPEG.

    Args:
      dir_path: directory containing the images to clean.
      model_path: pickle file written by train_model(). It is unpickled, so
        it must come from a trusted source.
      out_dir: directory for the cleaned images. The default None overwrites
        the input files in place (the original behavior).

    Raises:
      ValueError: if the pickle does not hold exactly one value per trainable
        variable in the current graph.
    """
    filenames = [os.path.join(dir_path, x) for x in os.listdir(dir_path) if x.endswith('.JPG')]
    if not filenames:
        # Nothing to do; model() cannot build variables from an empty batch.
        return
    inputs = load_images(filenames)
    cleaned = tf.clip_by_value(model(tf.constant(inputs)), 0, 1)
    with open(model_path, 'rb') as in_file:
        # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
        new_vars = pickle.load(in_file)
    variables = tf.trainable_variables()
    if len(variables) != len(new_vars):
        # zip() would silently leave some variables uninitialized.
        raise ValueError('model file holds %d values but the graph has %d variables'
                         % (len(new_vars), len(variables)))
    with tf.Session() as sess:
        # The assignments double as initialization; values were pickled in
        # tf.trainable_variables() order by train_model().
        sess.run([tf.assign(var, val) for var, val in zip(variables, new_vars)])
        results = sess.run(cleaned)
    if out_dir is not None and not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    for filename, raw_img in zip(filenames, results):
        # Round rather than truncate: float32 error can leave k/255*255 just
        # below k, which astype() alone would turn into k-1.
        pixels = np.round(raw_img * 0xff).astype('uint8')
        img = Image.fromarray(pixels, 'RGB')
        if out_dir is not None:
            filename = os.path.join(out_dir, os.path.basename(filename))
        img.save(filename, quality=100)
def model(images, name='model', reuse=False, l2_reg=0.00001):
    """Apply a de-watermarker to an image batch.

    The model is a per-position affine map,
    output = images * (1 + scale) + bias, where scale and bias have the shape
    of one image and are zero-initialized, so the map starts as the identity.

    Args:
      images: float32 tensor whose first axis is the batch; every later
        dimension must be statically known because it fixes the variable shape.
      name: variable-scope name.
      reuse: if truthy, reuse the variables in the scope named exactly `name`
        instead of opening a fresh, uniquified scope.
      l2_reg: scale of the L2 regularizer attached to the variables. The
        penalty only has an effect if the caller adds the graph's
        regularization losses to its objective.

    Returns:
      A tensor with the same shape as images.
    """
    regularizer = tf.contrib.layers.l2_regularizer(l2_reg)
    if reuse:
        # tf.variable_scope rejects reuse when name_or_scope is None, so reuse
        # the explicitly named scope instead of a default_name scope.
        scope = tf.variable_scope(name, reuse=reuse, regularizer=regularizer)
    else:
        scope = tf.variable_scope(None, default_name=name, reuse=reuse,
                                  regularizer=regularizer)
    with scope:
        shape = tuple(x.value for x in images.get_shape()[1:])
        scales = tf.get_variable('scale', shape=shape, dtype=tf.float32,
                                 initializer=tf.zeros_initializer())
        biases = tf.get_variable('bias', shape=shape, dtype=tf.float32,
                                 initializer=tf.zeros_initializer())
        # If regularization losses are included in training, weight decay pulls
        # the transform back toward the identity unless the data says otherwise.
        return images * (scales + 1) + biases
def pair_paths(dir_path):
    """Get a list of (water, no_water) path tuples.

    Every file named '<stem>_w.JPG' in dir_path is paired with '<stem>.JPG'
    in the same directory. Pairs are sorted by watermarked filename so the
    training batch order is reproducible across runs.

    Raises:
      FileNotFoundError (an IOError with errno ENOENT on Python 2): if a
        watermarked file has no unwatermarked counterpart.
    """
    import errno  # local import keeps this helper self-contained

    suffix = '_w.JPG'
    pairs = []
    for item in sorted(os.listdir(dir_path)):
        if not item.endswith(suffix):
            continue
        watermarked = os.path.join(dir_path, item)
        clean = os.path.join(dir_path, item[:-len(suffix)] + '.JPG')
        if not os.path.isfile(clean):
            # Fail here with a clear message instead of deep inside image loading.
            raise IOError(errno.ENOENT, 'no unwatermarked counterpart for ' + watermarked, clean)
        pairs.append((watermarked, clean))
    return pairs
def load_images(paths, width=1000, height=666):
    """Load a batch of images.

    Each image is resized to (width, height), converted to RGB, and scaled
    from [0, 255] to [0, 1].

    Returns:
      float32 array of shape (len(paths), height, width, 3). An empty `paths`
      yields an empty array with the same trailing shape.
    """
    paths = list(paths)
    if not paths:
        return np.zeros((0, height, width, 3), dtype='float32')
    arrays = []
    for path in paths:
        # Close each file deterministically instead of relying on Pillow/GC.
        with Image.open(path) as img:
            arrays.append(np.array(img.resize((width, height)).convert('RGB')))
    return np.array(arrays).astype('float32') / 0xff
# pylint: disable=E1129
if __name__ == '__main__':
    # Each stage gets its own graph, so tf.trainable_variables() inside
    # apply_model() sees only the variables it creates and then restores.
    # Comment out the training stage to reuse an existing model instead of
    # re-training a new one.
    with tf.Graph().as_default():
        train_model()
    with tf.Graph().as_default():
        apply_model()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment