Last active: October 24, 2019 06:13
Save gs18113/5f24d05104d11e4a0d928f8876eaff67 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from os.path import exists, join, basename | |
from os import makedirs, remove | |
from six.moves import urllib | |
import tarfile | |
import tensorflow as tf | |
# Mostly from https://github.com/pytorch/examples/tree/master/super_resolution | |
def _extract_within(archive_path, dest):
    """Extract the tar archive at *archive_path* into *dest*, member by member.

    Raises:
        ValueError: if a member (or the target of a link member) would resolve
            to a path outside *dest*.
    """
    import os

    root = os.path.realpath(dest)

    def _inside(path):
        return os.path.commonpath([root, path]) == root

    with tarfile.open(archive_path) as tar:
        for item in tar:
            target = os.path.realpath(os.path.join(root, item.name))
            if not _inside(target):
                raise ValueError("refusing to extract %r outside %r" % (item.name, dest))
            if item.issym() or item.islnk():
                # Symlink targets are relative to the link's own directory;
                # hard-link targets are relative to the archive root.
                base = os.path.dirname(target) if item.issym() else root
                if not _inside(os.path.realpath(os.path.join(base, item.linkname))):
                    raise ValueError("refusing link %r pointing outside %r" % (item.name, dest))
            tar.extract(item, dest)


def download_bsd300(dest="image_data"):
    """Download and unpack the BSDS300 image archive into *dest*, once.

    Args:
        dest: directory that receives the extracted ``BSDS300`` folder.

    Returns:
        Path of ``<dest>/BSDS300/images``. The dataset helpers look for its
        ``train`` and ``test`` sub-folders.

    Raises:
        ValueError: if the archive tries to write outside *dest*.
    """
    import shutil

    output_image_dir = join(dest, "BSDS300", "images")
    # Check the images folder itself, not just BSDS300/, so that a run
    # interrupted before extraction is retried instead of treated as done.
    if not exists(output_image_dir):
        makedirs(dest, exist_ok=True)
        url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz"
        print("downloading url ", url)
        file_path = join(dest, basename(url))
        # Stream to disk in chunks and close the HTTP response afterwards.
        with urllib.request.urlopen(url) as response, open(file_path, 'wb') as f:
            shutil.copyfileobj(response, f)
        print("Extracting data")
        try:
            _extract_within(file_path, dest)
        finally:
            # Always remove the downloaded archive, even if extraction fails.
            remove(file_path)
    return output_image_dir
def get_image_from_file(filename, crop_size=256, upscale_factor=2):
    """Load one JPEG and build a (low-resolution, high-resolution) pair.

    The image is center-cropped to a ``crop_size`` square and a copy is
    resized down by ``upscale_factor``. ``crop_size`` is first rounded down
    to a multiple of ``upscale_factor`` so the sides divide exactly. Both
    images are divided by 255 and returned channels-first (C, H, W).

    Args:
        filename: scalar string tensor holding the path of a JPEG file.
        crop_size: side of the high-resolution center crop. The image must be
            at least this large in both dimensions.
        upscale_factor: ratio between high- and low-resolution side lengths.
            The default of 2 reproduces the previous fixed behaviour.

    Returns:
        ``(downsampled_image, original_image)`` float32 tensors of shapes
        ``(3, crop // upscale_factor, crop // upscale_factor)`` and
        ``(3, crop, crop)``, where ``crop`` is the rounded crop size.
    """
    crop_size -= crop_size % upscale_factor
    low_res_size = crop_size // upscale_factor

    image = tf.io.read_file(filename)
    # channels=3 forces RGB output even for grayscale files, so every example
    # has the same channel count and the dataset can be batched.
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.cast(image, tf.float32)
    image_height = tf.shape(image)[0]
    image_width = tf.shape(image)[1]
    offset_height = (image_height - crop_size) // 2
    offset_width = (image_width - crop_size) // 2
    original_image = tf.image.crop_to_bounding_box(
        image, offset_height, offset_width, crop_size, crop_size)
    downsampled_image = tf.image.resize(original_image, [low_res_size, low_res_size])
    # Scale by 1/255 and convert HWC -> CHW: the ESPCN model reshapes its input
    # to [-1, H, W, 1], i.e. it processes each channel as a separate image.
    original_image = tf.transpose(original_image / 255.0, [2, 0, 1])
    downsampled_image = tf.transpose(downsampled_image / 255.0, [2, 0, 1])
    return downsampled_image, original_image


def _bsd300_pairs(split, upscale_factor):
    """Return a dataset of image pairs built from BSDS300 images/<split>/*.jpg."""
    root_dir = download_bsd300()
    names = tf.data.Dataset.list_files(join(root_dir, "%s/*.jpg" % split))
    # Parallel map keeps element order deterministic while decoding
    # several files at once.
    return names.map(
        lambda filename: get_image_from_file(filename, upscale_factor=upscale_factor),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)


def get_training_set(upscale_factor):
    """Return a tf.data.Dataset of (low-res, high-res) pairs from the train split."""
    return _bsd300_pairs("train", upscale_factor)


def get_test_set(upscale_factor):
    """Return a tf.data.Dataset of (low-res, high-res) pairs from the test split."""
    return _bsd300_pairs("test", upscale_factor)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow import keras | |
class ESPCN(keras.Model):
    """ESPCN-style sub-pixel super-resolution network.

    Four same-padded convolutions (tanh activations, orthogonal init) produce
    ``upscale_factor ** 2`` feature maps at the input resolution.
    ``tf.nn.depth_to_space`` then rearranges those maps into a single channel
    ``upscale_factor`` times taller and wider.
    """

    def __init__(self, upscale_factor, input_size=128):
        """Build the layers.

        Args:
            upscale_factor: integer magnification produced by the model.
            input_size: low-resolution input size as an int (square) or a
                ``(height, width)`` pair. Defaults to 128, the previously
                hard-coded size. Pass ``None`` to take height and width from
                the last two axes of each input at call time.
        """
        super().__init__()
        self.conv1 = keras.layers.Conv2D(64, 5, padding='same', activation='tanh', kernel_initializer='orthogonal')
        self.conv2 = keras.layers.Conv2D(64, 3, padding='same', activation='tanh', kernel_initializer='orthogonal')
        self.conv3 = keras.layers.Conv2D(64, 3, padding='same', activation='tanh', kernel_initializer='orthogonal')
        self.conv4 = keras.layers.Conv2D((upscale_factor ** 2), 3, padding='same', activation='tanh', kernel_initializer='orthogonal')
        self.upscale_factor = upscale_factor
        # Store plain ints (or None) rather than a tuple so Keras attribute
        # tracking does not wrap the value.
        if input_size is None:
            self._input_height = self._input_width = None
        elif isinstance(input_size, (tuple, list)):
            self._input_height, self._input_width = input_size
        else:
            self._input_height = self._input_width = input_size

    def call(self, x):
        """Upscale ``x``.

        ``x`` is reshaped to ``[-1, height, width, 1]``, so every
        ``height x width`` plane is processed as its own single-channel image.
        The inferred-size mode (``input_size=None``) assumes the last two axes
        of ``x`` are (H, W), e.g. channels-first ``(batch, C, H, W)``.
        NOTE(review): confirm this layout for callers other than train.py.

        Returns:
            Tensor of shape ``[N, height * r, width * r, 1]`` with
            ``r = upscale_factor``.
        """
        if self._input_height is None:
            dims = tf.shape(x)
            height, width = dims[-2], dims[-1]
        else:
            height, width = self._input_height, self._input_width
        x = tf.reshape(x, [-1, height, width, 1])
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = tf.nn.depth_to_space(x, self.upscale_factor)
        return x
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import tensorflow as tf | |
from model import ESPCN | |
from data import get_training_set, get_test_set | |
import logging | |
from os.path import join | |
# Timestamped log lines. %(levelname)s is used instead of a literal tag so
# warnings and errors are labelled with their real level.
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
def str2bool(v):
    """Interpret a command-line token as a boolean (argparse ``type=`` hook).

    Values that are already ``bool`` are returned unchanged. Strings are
    matched case-insensitively: yes/true/t/y/1 -> True, no/false/f/n/0 -> False.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(v, bool):
        return v
    token = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
parser = argparse.ArgumentParser()
parser.add_argument('-upscale_factor', default=2, type=int,
                    help='super-resolution magnification factor')
parser.add_argument('-num_epochs', default=100, type=int,
                    help='number of training epochs')
parser.add_argument('-batch_size', default=32, type=int,
                    help='images per (global) batch')
parser.add_argument('-seed', default=123, type=int,
                    help='global TensorFlow random seed')
parser.add_argument('-lr', default=0.01, type=float,
                    help='initial learning rate of the decay schedule')
parser.add_argument('-save_dir', default='saved_models', type=str,
                    help='directory that receives the SavedModel snapshots')
# nargs='?' accepts a bare flag ("-use_tpu") or an explicit value
# ("-use_tpu false"). const=True is the value of the bare form; without it
# argparse stores None and the TPU path is silently skipped.
parser.add_argument('-use_tpu', type=str2bool, nargs='?', const=True, default=False,
                    help='train on a TPU (bare flag means true)')
args = parser.parse_args()
tf.random.set_seed(args.seed)
# TPU objects: in TPU mode, connect to and initialize the TPU system, create
# a TPUStrategy, and build the model inside that strategy's scope. Otherwise
# build the model directly.
tpu_strategy = None
model = None
if not args.use_tpu:
    model = ESPCN(args.upscale_factor)
else:
    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_host(cluster_resolver.master())
    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
    tpu_strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    with tpu_strategy.scope():
        model = ESPCN(args.upscale_factor)
# Loss & optimizer: Adam driven by a staircase exponential schedule, i.e.
# lr = args.lr * 0.99 ** floor(step / 400).
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=args.lr, decay_steps=400, decay_rate=0.99, staircase=True)
optimizer = None
if not args.use_tpu:
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
else:
    # Created under the same strategy scope as the model.
    with tpu_strategy.scope():
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
# Dataset: shuffled training batches and in-order test batches. In TPU mode
# both are distributed across the strategy's replicas.
train_dataset = None
test_dataset = None
if args.use_tpu:
    with tpu_strategy.scope():
        # Drop the final partial training batch so every distributed step has
        # the same static batch shape and a full global batch of examples.
        train_dataset = get_training_set(args.upscale_factor).shuffle(200).batch(
            args.batch_size, drop_remainder=True)
        train_dataset = tpu_strategy.experimental_distribute_dataset(train_dataset)
        # NOTE(review): the test set keeps its partial last batch so no test
        # image is skipped; confirm the TPU runtime accepts it.
        test_dataset = get_test_set(args.upscale_factor).batch(args.batch_size)
        test_dataset = tpu_strategy.experimental_distribute_dataset(test_dataset)
else:
    train_dataset = get_training_set(args.upscale_factor).shuffle(200).batch(args.batch_size)
    test_dataset = get_test_set(args.upscale_factor).batch(args.batch_size)
# Train & test steps


def _reconstruction_mse(generated_image, image):
    """Return the mean squared error between a model output and its target.

    ``image`` is reshaped to the shape of ``generated_image``, so the tensors
    are compared element by element in row-major order. This is the same
    pairing a ``[-1, H*W]`` reshape of both tensors gives, but without
    hard-coding the crop size.
    """
    target = tf.reshape(image, tf.shape(generated_image))
    return tf.reduce_mean(tf.math.squared_difference(generated_image, target))


train_step = None
test_step = None
if args.use_tpu:
    with tpu_strategy.scope():
        @tf.function
        def train_step_tpu(dist_inputs):
            """Run one distributed training step and return the global mean loss."""
            def step_fn(inputs):
                ds_image, image = inputs
                with tf.GradientTape() as tape:
                    generated_image = model(ds_image)
                    # Replica gradients are summed when applied. Dividing each
                    # replica's mean by the replica count makes the update
                    # follow the global-batch mean, assuming equal-sized
                    # replica batches.
                    loss = (_reconstruction_mse(generated_image, image)
                            / tpu_strategy.num_replicas_in_sync)
                gradients = tape.gradient(loss, model.trainable_variables)
                optimizer.apply_gradients(zip(gradients, model.trainable_variables))
                return loss

            per_replica_losses = tpu_strategy.experimental_run_v2(
                step_fn, args=(dist_inputs, ))
            # The sum of the pre-scaled replica means is the global mean, the
            # same quantity train_step_normal reports.
            return tpu_strategy.reduce(
                tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

        @tf.function
        def test_step_tpu(dist_inputs):
            """Evaluate one distributed batch and return the global mean loss."""
            def step_fn(inputs):
                ds_image, image = inputs
                generated_image = model(ds_image)
                return (_reconstruction_mse(generated_image, image)
                        / tpu_strategy.num_replicas_in_sync)

            per_replica_losses = tpu_strategy.experimental_run_v2(
                step_fn, args=(dist_inputs, ))
            return tpu_strategy.reduce(
                tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

        train_step = train_step_tpu
        test_step = test_step_tpu
else:
    @tf.function
    def train_step_normal(ds_image, image):
        """Run one optimization step on a batch and return its mean loss."""
        with tf.GradientTape() as tape:
            generated_image = model(ds_image)
            loss = _reconstruction_mse(generated_image, image)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss

    @tf.function
    def test_step_normal(ds_image, image):
        """Return the mean loss of the model on a batch (no update)."""
        generated_image = model(ds_image)
        return _reconstruction_mse(generated_image, image)

    train_step = train_step_normal
    test_step = test_step_normal
# Epoch loop: one training pass and one evaluation pass per epoch. Whenever
# the mean test loss improves on the best so far, the model is exported as a
# SavedModel under <save_dir>/<0-based epoch index>.
best_model = 0
best_test_loss = 1000000
for epoch in range(args.num_epochs):
    # Training pass: accumulate the loss returned by each step.
    train_loss_sum = 0
    train_cnt = 0
    if not args.use_tpu:
        for ds_image, image in train_dataset:
            train_loss_sum += train_step(ds_image, image)
            train_cnt += 1
    else:
        with tpu_strategy.scope():
            for inputs in train_dataset:
                train_loss_sum += train_step(inputs)
                train_cnt += 1

    # Evaluation pass.
    test_loss_sum = 0
    test_cnt = 0
    if not args.use_tpu:
        for test_ds_image, test_image in test_dataset:
            test_loss_sum += test_step(test_ds_image, test_image)
            test_cnt += 1
    else:
        with tpu_strategy.scope():
            for inputs in test_dataset:
                test_loss_sum += test_step(inputs)
                test_cnt += 1

    epoch_test_loss = test_loss_sum / test_cnt
    if epoch_test_loss < best_test_loss:
        best_model = epoch
        best_test_loss = epoch_test_loss
        save_path = join(args.save_dir, str(epoch))
        tf.saved_model.save(model, save_path)

    logging.info('epoch: %d, train_loss: %f, test_loss: %f' % (epoch+1, train_loss_sum / train_cnt, epoch_test_loss))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.