@ptrcarta
Created April 12, 2020 16:36
dcgan tpu
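A DCGAN trained on the CelebA face dataset in TensorFlow 2; the tf.distribute strategy is picked at runtime for a TPU, Colab GPU, or CPU session.
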
import os
import tensorflow as tf
from tensorflow import keras
from timeit import default_timer as timer
import matplotlib.pyplot as plt
import numpy as np
IMAGE_DATA_FORMAT = 'channels_last'
keras.backend.set_image_data_format(IMAGE_DATA_FORMAT)
batchnorm_axis = 3 if IMAGE_DATA_FORMAT == 'channels_last' else 1
# Pick a distribution strategy based on the runtime: TPU, Colab GPU, or CPU.
if 'TPU_NAME' in os.environ:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)
elif 'COLAB_GPU' in os.environ and int(os.environ['COLAB_GPU']) > 0:
    strategy = tf.distribute.MirroredStrategy()
else:
    strategy = tf.distribute.OneDeviceStrategy(device='/cpu:0')
BATCH_SIZE = 16
GLOBAL_BATCH_SIZE = BATCH_SIZE * strategy.num_replicas_in_sync
EPOCHS = 10
ADAM_LR = 0.0002
ADAM_BETA1 = 0.5
ADAM_BETA2 = 0.999
nz = 256  # latent (noise) vector dimension
def D():
    """Discriminator: maps a 109x89 RGB image to a single real/fake logit."""
    return keras.Sequential([
        keras.layers.Conv2D(16, 4, strides=4),
        keras.layers.ReLU(),
        keras.layers.BatchNormalization(axis=batchnorm_axis),
        keras.layers.Conv2D(128, 3, strides=1),
        keras.layers.ReLU(),
        keras.layers.BatchNormalization(axis=batchnorm_axis),
        keras.layers.Conv2D(256, 3, strides=1),
        keras.layers.ReLU(),
        keras.layers.BatchNormalization(axis=batchnorm_axis),
        keras.layers.Conv2D(256, 3, strides=3),
        keras.layers.ReLU(),
        keras.layers.BatchNormalization(axis=batchnorm_axis),
        keras.layers.Conv2D(1, (7, 6), strides=1),
    ])
def G():
    """Generator: maps a length-nz latent vector to a 109x89 RGB image in [0, 1]."""
    return keras.Sequential([
        (keras.layers.Reshape((1, 1, nz)) if IMAGE_DATA_FORMAT == 'channels_last'
         else keras.layers.Reshape((nz, 1, 1))),
        keras.layers.Conv2DTranspose(2048, 2, strides=2),
        keras.layers.LeakyReLU(),
        keras.layers.BatchNormalization(axis=batchnorm_axis),
        keras.layers.Conv2DTranspose(1024, 2, strides=2),
        keras.layers.LeakyReLU(),
        keras.layers.BatchNormalization(axis=batchnorm_axis),
        keras.layers.Conv2DTranspose(512, 2, strides=2),
        keras.layers.LeakyReLU(),
        keras.layers.BatchNormalization(axis=batchnorm_axis),
        keras.layers.Conv2DTranspose(256, 2, strides=2),
        keras.layers.LeakyReLU(),
        keras.layers.BatchNormalization(axis=batchnorm_axis),
        keras.layers.Conv2DTranspose(128, 2, strides=2),
        keras.layers.LeakyReLU(),
        keras.layers.BatchNormalization(axis=batchnorm_axis),
        keras.layers.Conv2DTranspose(64, 2, strides=2),
        keras.layers.LeakyReLU(),
        keras.layers.BatchNormalization(axis=batchnorm_axis),
        keras.layers.Conv2DTranspose(32, 2, strides=2),
        keras.layers.LeakyReLU(),
        keras.layers.BatchNormalization(axis=batchnorm_axis),
        keras.layers.Conv2DTranspose(32, 3, strides=1),
        keras.layers.LeakyReLU(),
        keras.layers.BatchNormalization(axis=batchnorm_axis),
        keras.layers.Conv2DTranspose(32, 3, strides=1),
        keras.layers.LeakyReLU(),
        keras.layers.BatchNormalization(axis=batchnorm_axis),
        keras.layers.Conv2D(3, 1, strides=1),
        keras.layers.Activation(tf.sigmoid),
        # The transposed convs yield 132x132; crop to 109x89 to match the
        # half-resolution CelebA images.
        keras.layers.Cropping2D(((0, 23), (0, 43))),
    ])
class TimedInputs:
    """Context manager that periodically prints throughput (items counted per second)."""
    def __init__(self, secs):
        self._count = 0
        self._start_time = None
        self._last_time = None
        self._secs = secs

    def __enter__(self):
        self._start_time = timer()
        return self

    def _print_rate(self):
        print(f'rate: {self._count/(self._last_time - self._start_time):.2f}')

    def count(self, num):
        self._last_time = timer()
        self._count += num
        if self._last_time - self._start_time > self._secs:
            self._print_rate()
            self._start_time = self._last_time
            self._count = 0

    def __exit__(self, typ, value, tb):
        self._last_time = timer()
        self._print_rate()
# Schema of the TFRecord examples written by make_tfrecords().
feature_description = {
    'name': tf.io.FixedLenFeature([], tf.string, default_value=''),
    'img': tf.io.FixedLenFeature([], tf.string, default_value=''),
}
def make_tfrecords():
    """Shard the local CelebA JPEGs into TFRecord files on GCS, imgs_file images per shard."""
    imgs_file = 10000

    def name_and_content(fn):
        return fn, tf.io.read_file(fn)

    ds = tf.data.Dataset.list_files(
        'data/img_align_celeba/*jpg', shuffle=False).map(name_and_content)
    for i, (fn, img) in enumerate(ds):
        if i % imgs_file == 0:
            print('img ', i, '/', 202599)
            w = tf.io.TFRecordWriter(
                f'gs://pietor-euw4/celeba/celeba_{i//imgs_file:03d}.tfr')
        name_f = tf.train.Feature(bytes_list=tf.train.BytesList(value=[fn.numpy()]))
        img_f = tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.numpy()]))
        feature = {'name': name_f, 'img': img_f}
        features = tf.train.Features(feature=feature)
        example_proto = tf.train.Example(features=features)
        w.write(example_proto.SerializeToString())
        if ((i + 1) % imgs_file == 0) or ((202599 - 1) == i):
            w.close()
    print('written')
def parse_examples(ex_serialized):
    return tf.io.parse_single_example(ex_serialized, feature_description)

def decode_image(image):
    # Decode the JPEG, downsample 218x178 -> 109x89, and scale to [0, 1].
    img = tf.io.decode_image(image, expand_animations=False)
    img = tf.cast(img[::2, ::2, :], tf.float32) / 255.
    if IMAGE_DATA_FORMAT == 'channels_first':
        img = tf.transpose(img, [2, 0, 1])
    return img
def get_dataset_local():
    return tf.data.Dataset.list_files(
        'data/img_align_celeba/*jpg', shuffle=False
    ).map(tf.io.read_file).map(decode_image)

def drop_names(d):
    return d['img']
def get_dataset_gcs():
    return tf.data.TFRecordDataset(
        tf.io.matching_files('gs://pietor-euw4/celeba/celeba_*')
    ).prefetch(1024).map(parse_examples).map(drop_names).map(decode_image)
if __name__ == '__main__':
    # Build the models and optimizers under the distribution strategy.
    with strategy.scope():
        d = D()
        g = G()
        optD = keras.optimizers.Adam(ADAM_LR, ADAM_BETA1, ADAM_BETA2)
        optG = keras.optimizers.Adam(ADAM_LR, ADAM_BETA1, ADAM_BETA2)

    @tf.function
    def train_step(real_inputs):
        # Per-replica batch of latent vectors.
        zs = tf.random.normal(
            [GLOBAL_BATCH_SIZE // strategy.num_replicas_in_sync, nz])
        with tf.GradientTape() as d_tape, tf.GradientTape() as g_tape:
            fake_inputs = g(zs)

            ## Train D
            d_out_real = d(real_inputs)
            d_out_fake = d(fake_inputs)
            loss_d_real = keras.losses.binary_crossentropy(
                tf.ones_like(d_out_real), d_out_real, from_logits=True)
            loss_d_fake = keras.losses.binary_crossentropy(
                tf.zeros_like(d_out_fake), d_out_fake, from_logits=True)
            loss_d = loss_d_real + loss_d_fake
            grad_d = d_tape.gradient(loss_d, d.trainable_variables)
            optD.apply_gradients(zip(grad_d, d.trainable_variables))
            ddd = tf.linalg.norm(list(tf.nest.map_structure(tf.linalg.norm, grad_d)))
            tf.print('grad_d norm', ddd)

            ## Train G (non-saturating loss: maximize log D(G(z)))
            dg_out_fake = d(fake_inputs)
            loss_g = keras.losses.binary_crossentropy(
                tf.ones_like(dg_out_fake),
                dg_out_fake, from_logits=True)
            # loss_g = -keras.losses.binary_crossentropy(
            #     tf.zeros_like(dg_out_fake), dg_out_fake, from_logits=True)
            grad_g = g_tape.gradient(loss_g, g.trainable_variables)
            optG.apply_gradients(zip(grad_g, g.trainable_variables))
            ddd = tf.linalg.norm(list(tf.nest.map_structure(tf.linalg.norm, grad_g)))
            tf.print('grad_g norm', ddd)

    # Epoch loop
    for epoch in range(EPOCHS):
        print(f'epoch: {epoch}')
        dataset = get_dataset_gcs().batch(GLOBAL_BATCH_SIZE, drop_remainder=True)
        dataset = strategy.experimental_distribute_dataset(dataset)
        with TimedInputs(3) as t:
            for inputs in dataset:
                strategy.experimental_run_v2(train_step, (inputs,))
                t.count(GLOBAL_BATCH_SIZE)
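
    # Hedged sketch, not part of the original gist: after training, draw a few
    # latents and save a grid of generated samples with the imported
    # matplotlib/numpy. The grid size and output filename are assumptions, and
    # on a TPU runtime the generator call may need to run under the strategy.
    samples = g(tf.random.normal([16, nz])).numpy()
    if IMAGE_DATA_FORMAT == 'channels_first':
        samples = np.transpose(samples, [0, 2, 3, 1])
    fig, axes = plt.subplots(4, 4, figsize=(8, 8))
    for ax, img in zip(axes.flat, samples):
        ax.imshow(np.clip(img, 0.0, 1.0))
        ax.axis('off')
    fig.savefig('samples.png')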