Gist by @gs18113 (last active October 24, 2019): https://gist.github.com/gs18113/5f24d05104d11e4a0d928f8876eaff67
# data.py (imported by the training script below as "from data import ...")
from os.path import exists, join, basename
from os import makedirs, remove
from six.moves import urllib
import tarfile
import tensorflow as tf

# Mostly from https://github.com/pytorch/examples/tree/master/super_resolution
def download_bsd300(dest="image_data"):
    output_image_dir = join(dest, "BSDS300")
    if not exists(output_image_dir):
        makedirs(output_image_dir)
        url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz"
        print("downloading url ", url)
        data = urllib.request.urlopen(url)
        file_path = join(dest, basename(url))
        with open(file_path, 'wb') as f:
            f.write(data.read())
        print("Extracting data")
        with tarfile.open(file_path) as tar:
            for item in tar:
                tar.extract(item, dest)
        remove(file_path)
    return join(output_image_dir, "images")
def get_image_from_file(filename, crop_size=256):
    image = tf.io.read_file(filename)
    image = tf.image.decode_jpeg(image)
    image = tf.cast(image, tf.float32)
    image_height = tf.shape(image)[0]
    image_width = tf.shape(image)[1]
    offset_height = (image_height - crop_size) // 2
    offset_width = (image_width - crop_size) // 2
    original_image = tf.image.crop_to_bounding_box(image, offset_height, offset_width, crop_size, crop_size)
    downsampled_image = tf.image.resize(original_image, [crop_size // 2, crop_size // 2])
    # Convert to the 0~1 range and change HWC to CHW
    # (the network accepts a single channel, so each channel is later folded
    # into the batch dimension; see the reshape in model.py's call()).
    original_image = tf.transpose(original_image / 255.0, [2, 0, 1])
    downsampled_image = tf.transpose(downsampled_image / 255.0, [2, 0, 1])
    return downsampled_image, original_image
def get_training_set(upscale_factor):
    root_dir = download_bsd300()
    train_dir = join(root_dir, "train/*.jpg")
    names = tf.data.Dataset.list_files(train_dir)
    images = names.map(get_image_from_file)
    return images

def get_test_set(upscale_factor):
    root_dir = download_bsd300()
    test_dir = join(root_dir, "test/*.jpg")
    names = tf.data.Dataset.list_files(test_dir)
    images = names.map(get_image_from_file)
    return images
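# A minimal sanity check for the pipeline above (not part of the original gist).
# It downloads BSDS300 on first run via download_bsd300() and prints the shapes of
# one (downsampled, original) pair, which should be (3, 128, 128) and (3, 256, 256)
# in CHW layout given the default crop_size=256.
if __name__ == "__main__":
    train_set = get_training_set(upscale_factor=2)
    for downsampled, original in train_set.take(1):
        print("downsampled:", downsampled.shape)  # expected (3, 128, 128)
        print("original:", original.shape)        # expected (3, 256, 256)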
# model.py (imported by the training script below as "from model import ESPCN")
import tensorflow as tf
from tensorflow import keras

class ESPCN(keras.Model):
    def __init__(self, upscale_factor):
        super().__init__()
        self.conv1 = keras.layers.Conv2D(64, 5, padding='same', activation='tanh', kernel_initializer='orthogonal')
        self.conv2 = keras.layers.Conv2D(64, 3, padding='same', activation='tanh', kernel_initializer='orthogonal')
        self.conv3 = keras.layers.Conv2D(64, 3, padding='same', activation='tanh', kernel_initializer='orthogonal')
        self.conv4 = keras.layers.Conv2D(upscale_factor ** 2, 3, padding='same', activation='tanh', kernel_initializer='orthogonal')
        self.upscale_factor = upscale_factor

    def call(self, x):
        # Incoming tensors are CHW; fold the channel axis into the batch axis so
        # each channel is processed as a separate single-channel 128x128 image.
        x = tf.reshape(x, [-1, 128, 128, 1])
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # Pixel shuffle: rearrange upscale_factor**2 channels into spatial resolution.
        x = tf.nn.depth_to_space(x, self.upscale_factor)
        return x
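# A small shape-check sketch (not in the original gist). Because call() folds the
# 3 channels into the batch axis, a (4, 3, 128, 128) CHW batch with upscale_factor=2
# becomes 12 single-channel examples and comes out as (12, 256, 256, 1) after
# depth_to_space.
if __name__ == "__main__":
    model = ESPCN(upscale_factor=2)
    dummy = tf.zeros([4, 3, 128, 128])
    out = model(dummy)
    print(out.shape)  # expected (12, 256, 256, 1)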
# Training script (uses data.py and model.py above)
import argparse
import tensorflow as tf
from model import ESPCN
from data import get_training_set, get_test_set
import logging
from os.path import join

logging.basicConfig(level=logging.INFO, format='%(asctime)s [INFO] %(message)s')

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

parser = argparse.ArgumentParser()
parser.add_argument('-upscale_factor', default=2, type=int)
parser.add_argument('-num_epochs', default=100, type=int)
parser.add_argument('-batch_size', default=32, type=int)
parser.add_argument('-seed', default=123, type=int)
parser.add_argument('-lr', default=0.01, type=float)
parser.add_argument('-save_dir', default='saved_models', type=str)
parser.add_argument('-use_tpu', type=str2bool, nargs='?', default=False)
args = parser.parse_args()
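# Example invocation (assuming this script is saved as main.py, which the gist does not state):
#   python main.py -upscale_factor 2 -num_epochs 100 -batch_size 32 -use_tpu false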
tf.random.set_seed(args.seed)

# TPU objects
tpu_strategy = None
model = None
if args.use_tpu:
    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_host(cluster_resolver.master())
    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
    tpu_strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    with tpu_strategy.scope():
        model = ESPCN(args.upscale_factor)
else:
    model = ESPCN(args.upscale_factor)
# Loss & optimizer
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
args.lr,
decay_steps=400,
decay_rate=0.99,
staircase=True)
optimizer = None
if args.use_tpu:
with tpu_strategy.scope():
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
else:
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
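# Note (not in the original gist): with staircase=True this schedule evaluates to
# args.lr * 0.99 ** (step // 400), so after 4000 optimizer steps the learning rate
# has decayed to roughly 0.01 * 0.99**10 ≈ 0.00904.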
# Dataset
train_dataset = None
test_dataset = None
if args.use_tpu:
    with tpu_strategy.scope():
        train_dataset = get_training_set(args.upscale_factor).shuffle(200).batch(args.batch_size)
        train_dataset = tpu_strategy.experimental_distribute_dataset(train_dataset)
        test_dataset = get_test_set(args.upscale_factor).batch(args.batch_size)
        test_dataset = tpu_strategy.experimental_distribute_dataset(test_dataset)
else:
    train_dataset = get_training_set(args.upscale_factor).shuffle(200).batch(args.batch_size)
    test_dataset = get_test_set(args.upscale_factor).batch(args.batch_size)
# Train & test steps
train_step = None
test_step = None
if args.use_tpu:
with tpu_strategy.scope():
@tf.function
def train_step_tpu(dist_inputs):
def step_fn(inputs):
ds_image, image = inputs
with tf.GradientTape() as tape:
generated_image = model(ds_image)
loss_one = tf.reduce_sum(tf.reduce_mean(tf.math.squared_difference(tf.reshape(generated_image, [-1, 256*256]), tf.reshape(image, [-1, 256*256])), 1))
loss = loss_one * (1.0 / args.batch_size)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
return loss_one
per_example_losses = tpu_strategy.experimental_run_v2(
step_fn, args=(dist_inputs, ))
mean_loss = tpu_strategy.reduce(
tf.distribute.ReduceOp.MEAN, per_example_losses, axis=None)
return mean_loss
train_step = train_step_tpu
else:
@tf.function
def train_step_normal(ds_image, image):
with tf.GradientTape() as tape:
generated_image = model(ds_image)
loss = tf.reduce_mean(tf.math.squared_difference(tf.reshape(generated_image, [-1, 256*256]), tf.reshape(image, [-1, 256*256])))
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
return loss
train_step = train_step_normal
if args.use_tpu:
    with tpu_strategy.scope():
        @tf.function
        def test_step_tpu(dist_inputs):
            def step_fn(inputs):
                ds_image, image = inputs
                generated_image = model(ds_image)
                loss_one = tf.reduce_sum(tf.reduce_mean(
                    tf.math.squared_difference(
                        tf.reshape(generated_image, [-1, 256*256]),
                        tf.reshape(image, [-1, 256*256])), 1))
                return loss_one
            per_example_losses = tpu_strategy.experimental_run_v2(
                step_fn, args=(dist_inputs,))
            mean_loss = tpu_strategy.reduce(
                tf.distribute.ReduceOp.MEAN, per_example_losses, axis=None)
            return mean_loss
        test_step = test_step_tpu
else:
    @tf.function
    def test_step_normal(ds_image, image):
        generated_image = model(ds_image)
        loss = tf.reduce_mean(tf.math.squared_difference(
            tf.reshape(generated_image, [-1, 256*256]),
            tf.reshape(image, [-1, 256*256])))
        return loss
    test_step = test_step_normal
best_model = 0
best_test_loss = 1000000
for epoch in range(args.num_epochs):
    train_loss_sum = 0
    train_cnt = 0
    if args.use_tpu:
        with tpu_strategy.scope():
            for inputs in train_dataset:
                train_loss_sum += train_step(inputs)
                train_cnt += 1
    else:
        for ds_image, image in train_dataset:
            train_loss_sum += train_step(ds_image, image)
            train_cnt += 1
    test_loss_sum = 0
    test_cnt = 0
    if args.use_tpu:
        with tpu_strategy.scope():
            for inputs in test_dataset:
                test_loss_sum += test_step(inputs)
                test_cnt += 1
    else:
        for test_ds_image, test_image in test_dataset:
            test_loss_sum += test_step(test_ds_image, test_image)
            test_cnt += 1
    # Save a checkpoint whenever the mean test loss improves.
    if best_test_loss > (test_loss_sum / test_cnt):
        best_model = epoch
        best_test_loss = test_loss_sum / test_cnt
        save_path = join(args.save_dir, str(epoch))
        tf.saved_model.save(model, save_path)
    logging.info('epoch: %d, train_loss: %f, test_loss: %f' % (epoch + 1, train_loss_sum / train_cnt, test_loss_sum / test_cnt))
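# Optional sketch (not in the original gist): in the non-TPU path the tracked loss is a
# plain per-pixel MSE on 0~1 images, so PSNR, a more common super-resolution metric,
# can be derived from it as PSNR = 10 * log10(MAX^2 / MSE) with MAX = 1.0.
def psnr_from_mse(mse):
    return 10.0 * tf.math.log(1.0 / mse) / tf.math.log(10.0)

# e.g. logging.info('test PSNR: %f dB' % psnr_from_mse(test_loss_sum / test_cnt))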