Skip to content

Instantly share code, notes, and snippets.

View ethanyanjiali's full-sized avatar

Ethan Yanjia Li ethanyanjiali

View GitHub Profile
# Learning-rate schedules: one LinearDecay instance per optimizer, both built
# from identical arguments.
# NOTE(review): the 2nd/3rd arguments appear to be step counts measured in
# batches (epochs * total_batches); confirm their meaning against LinearDecay.
gen_lr_scheduler = LinearDecay(
    LEARNING_RATE,
    EPOCHS * total_batches,
    DECAY_EPOCHS * total_batches,
)
dis_lr_scheduler = LinearDecay(
    LEARNING_RATE,
    EPOCHS * total_batches,
    DECAY_EPOCHS * total_batches,
)

# Adam for generator and discriminator. Only learning_rate and beta_1 are set;
# every other Adam hyper-parameter keeps its Keras default.
optimizer_gen = tf.keras.optimizers.Adam(learning_rate=gen_lr_scheduler, beta_1=BETA_1)
optimizer_dis = tf.keras.optimizers.Adam(learning_rate=dis_lr_scheduler, beta_1=BETA_1)
def make_generator_model(n_blocks):
# 6 residual blocks
# c7s1-64,d128,d256,R256,R256,R256,R256,R256,R256,u128,u64,c7s1-3
# 9 residual blocks
# c7s1-64,d128,d256,R256,R256,R256,R256,R256,R256,R256,R256,R256,u128,u64,c7s1-3
model = tf.keras.Sequential()
# Encoding
model.add(ReflectionPad2d(3, input_shape=(256, 256, 3)))
model.add(tf.keras.layers.Conv2D(64, (7, 7), strides=(1, 1), padding='valid', use_bias=False))
def calc_gan_loss(prediction, is_real):
    """Score discriminator output against a constant real/fake target.

    Args:
        prediction: Discriminator output tensor.
        is_real: Truthy to compare against an all-ones target ("real"),
            falsy to compare against an all-zeros target ("fake").

    Returns:
        Whatever ``mse_loss(prediction, target)`` returns. ``mse_loss`` is
        defined elsewhere in the file (presumably a mean-squared-error loss,
        i.e. a least-squares GAN objective; confirm).
    """
    # Choose the target constructor; the target has the same shape as prediction.
    make_target = tf.ones_like if is_real else tf.zeros_like
    return mse_loss(prediction, make_target(prediction))
def calc_cycle_loss(reconstructed_images, real_images):
    """Cycle-consistency loss between a round-trip reconstruction and its source.

    Penalizes the difference between images translated to the other domain and
    back (``reconstructed_images``) and the original inputs (``real_images``).
    This keeps the round trip close to the input, as opposed to the adversarial
    loss, which is what pushes outputs to look real.

    Returns:
        Whatever ``mae_loss(reconstructed_images, real_images)`` returns.
        ``mae_loss`` is defined elsewhere in the file (presumably mean absolute
        error; confirm).
    """
    cycle_loss = mae_loss(reconstructed_images, real_images)
    return cycle_loss
@ethanyanjiali
ethanyanjiali / cyclegan_train_generator.py
Created June 6, 2019 06:00
cyclegan_train_generator
@tf.function
def train_generator(images_a, images_b):
real_a = images_a
real_b = images_b
with tf.GradientTape() as tape:
# Use real B to generate B should be identical
identity_a2b = generator_a2b(real_b, training=True)
identity_b2a = generator_b2a(real_a, training=True)
loss_identity_a2b = calc_identity_loss(identity_a2b, real_b)
loss_identity_b2a = calc_identity_loss(identity_b2a, real_a)
@ethanyanjiali
ethanyanjiali / cyclegan_discriminator.py
Created June 6, 2019 06:01
cyclegan_discriminator
def make_discriminator_model():
# C64-C128-C256-C512
model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(64, (4, 4), strides=(2, 2), padding='same', input_shape=(256, 256, 3)))
model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
model.add(tf.keras.layers.Conv2D(128, (4, 4), strides=(2, 2), padding='same', use_bias=False))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
@ethanyanjiali
ethanyanjiali / cyclegan_discriminator_loss.py
Created June 6, 2019 06:02
cyclegan_discriminator_loss
@tf.function
def train_discriminator(images_a, images_b, fake_a2b, fake_b2a):
real_a = images_a
real_b = images_b
with tf.GradientTape() as tape:
# Discriminator A should classify real_a as A
loss_gan_dis_a_real = calc_gan_loss(discriminator_a(real_a, training=True), True)
# Discriminator A should classify generated fake_b2a as not A
loss_gan_dis_a_fake = calc_gan_loss(discriminator_a(fake_b2a, training=True), False)
def train_step(images_a, images_b, epoch, step):
    """Run one training step: update the generators, then the discriminators.

    Args:
        images_a: Batch of domain-A images.
        images_b: Batch of domain-B images.
        epoch: Current epoch number.
        step: Current step number.

    NOTE(review): in the visible code ``epoch``, ``step`` and both loss dicts
    are unused and nothing is returned; presumably the full gist logs them.
    Confirm.
    """
    fake_a2b, fake_b2a, gen_losses = train_generator(images_a, images_b)

    # Route the freshly generated fakes through the fake-image pools before
    # the discriminators see them (the pool semantics are defined elsewhere;
    # presumably an image history buffer). The b2a pool is queried before the
    # a2b pool, matching the original order.
    pooled_b2a = fake_pool_b2a.query(fake_b2a)
    pooled_a2b = fake_pool_a2b.query(fake_a2b)

    dis_losses = train_discriminator(images_a, images_b, pooled_a2b, pooled_b2a)
def train(dataset, epochs):
for epoch in range(checkpoint.epoch+1, epochs+1):
@ethanyanjiali
ethanyanjiali / vim-setup.md
Last active September 26, 2019 05:17
Vim setup
  • Install Vim with Python support enabled
sudo git clone https://github.com/vim/vim.git && cd vim
sudo ./configure --with-features=huge --enable-multibyte --enable-pythoninterp=yes --with-python-config-dir=/usr/lib/python2.7/config-x86_64-linux-gnu/ --enable-python3interp=yes --with-python3-config-dir=/usr/lib/python3.5/config-3.5m-x86_64-linux-gnu/ --enable-gui=gtk2 --enable-cscope --prefix=/usr/local/
  • Use the following as ~/.vimrc:
"vundle
set nocompatible
filetype off
@ethanyanjiali
ethanyanjiali / install-horovod.md
Last active September 26, 2019 04:46
Install Horovod with TensorFlow 2.0 on Debian Stretch
  • Install gcc/g++ 7+. First, add this line to /etc/apt/sources.list:
deb http://ftp.de.debian.org/debian buster main 

Then install gcc/g++ 7:

sudo apt-get install gcc-7 g++-7
sudo rm /usr/bin/gcc
sudo rm /usr/bin/g++
@ethanyanjiali
ethanyanjiali / hg.py
Created March 14, 2020 19:33
Single Hourglass Module
def HourglassModule(inputs, order, filters, num_residual):
"""
One Hourglass Module. Usually we stacked multiple of them together.
https://github.com/princeton-vl/pose-hg-train/blob/master/src/models/hg.lua#L3
inputs:
order: The remaining order for HG modules to call itself recursively.
num_residual: Number of residual layers for this HG module.
"""
# Upper branch