@VyBui · Created April 21, 2020 02:24
bps_train
import os
import time

import segmentation_models as sm  # Segmentation Models: using `keras` framework.
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.client import device_lib

from config import cfg
from losses import schp_loss
from create_tf_records_bps import input_fn
from warm_start import get_learning_rate
from vgg19 import build_vgg19_model

keras.backend.set_image_data_format('channels_last')


def get_available_gpus():
    local_device_protos = device_lib.list_local_devices()
    return [x.name for x in local_device_protos if x.device_type == 'GPU']
def train():
    """Train the body-part segmentation model across all available GPUs.

    Enabling memory growth bounds the amount of GPU memory available to the
    TensorFlow process, which is common practice for local development when
    the GPU is shared with other applications such as a workstation GUI.
    """
    gpus = get_available_gpus()
    print(gpus)
    try:
        # `set_memory_growth` expects physical device objects, not device name strings.
        for gpu in tf.config.experimental.list_physical_devices('GPU'):
            tf.config.experimental.set_memory_growth(gpu, True)

        strategy = tf.distribute.MirroredStrategy()
        BATCH_SIZE_PER_REPLICA = 2
        BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
        print("Global batch size is: {}".format(BATCH_SIZE))

        # Build the training dataset and distribute it across the replicas.
        params = {'batch_size': BATCH_SIZE, 'tf_records_path': cfg.TF_RECORD_PATH}
        train_dataset = input_fn(mode="train", params=params)
        # test_dataset = input_fn(mode="test", params=params)
        train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
        # test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

        with strategy.scope():
            print("Building VGG19 model")
            vgg19 = build_vgg19_model()
        with strategy.scope():
            # Create the parsing network, optimizer, losses and metrics once, under
            # the strategy scope, so their variables are mirrored across replicas
            # (re-creating them inside the step function would fail under tf.function).
            segmentation_model = sm.PSPNet('resnet101', encoder_weights='imagenet')
            learning_rate = tf.Variable(cfg.max_learning_rate, trainable=False, dtype=tf.float32,
                                        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
            schp_optimizer = tf.keras.optimizers.Adam(learning_rate=lambda: learning_rate, beta_1=0.5)
            metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
            dice_loss = sm.losses.DiceLoss()
            focal_loss = sm.losses.CategoricalFocalLoss()
            parsing_loss_fn = dice_loss + (1 * focal_loss)
            # Keras callbacks only take effect with model.fit(); this script trains with a
            # custom loop, so the ModelCheckpoint / ReduceLROnPlateau callbacks below are
            # kept for reference only.
            callbacks = [
                tf.keras.callbacks.ModelCheckpoint(cfg.checkpoint_dir, save_weights_only=True,
                                                   save_best_only=True, mode='min'),
                tf.keras.callbacks.ReduceLROnPlateau(),
            ]

            def train_step(batch, current_epoch):
                """Run one optimisation step on a single replica and return its loss."""
                input_image, label, file_name = batch
                # Epoch-based learning-rate schedule from warm_start.py.
                learning_rate.assign(get_learning_rate(cfg.max_learning_rate, cfg.min_learning_rate,
                                                       current_epoch, cfg.EPOCHS))
                with tf.GradientTape() as gen_tape:
                    # VGG19 style/content activations of the input image.
                    [_, _], D_real_style_steps, D_real_content_steps = vgg19(input_image, training=True)
                    prediction = segmentation_model(input_image, training=True)
                    parsing_loss = parsing_loss_fn(label, prediction)
                    # NOTE: loss_edges and loss_consistent are never computed in this gist;
                    # they still have to be derived from the model outputs before
                    # schp_loss can be evaluated.
                    loss = schp_loss(loss_edges, parsing_loss, loss_consistent)
                segment_gradients = gen_tape.gradient(loss, segmentation_model.trainable_variables)
                schp_optimizer.apply_gradients(zip(segment_gradients,
                                                   segmentation_model.trainable_variables))
                # if step % 10 == 0:
                #     with tf.device("cpu:0"):
                #         with summary_writer.as_default():
                #             tf.summary.scalar('schp loss', loss, step=step, description='schp losses blocks')
                #             tf.summary.scalar('gan_loss', gan_loss, step=step, description='GANs losses blocks')
                #             tf.summary.scalar('gan_l1_loss', gan_l1_loss, step=step, description='GANs losses blocks')
                return loss
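            # `schp_loss` comes from losses.py, which is not included in this gist. Judging
            # from the call above it combines the edge, parsing and consistency terms; a
            # minimal sketch, assuming a plain weighted sum (illustrative weights only):
            #
            #     def schp_loss(loss_edges, parsing_loss, loss_consistent,
            #                   w_edge=1.0, w_parsing=1.0, w_consistent=1.0):
            #         return (w_edge * loss_edges + w_parsing * parsing_loss
            #                 + w_consistent * loss_consistent)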
            def test_step(batch, step):
                """Run the model in inference mode on one validation batch."""
                image, label, imagename = batch  # label_non_head is not needed
                gen_output_validation = segmentation_model(image, training=False)
                # with tf.device("cpu:0"):
                #     with summary_writer.as_default():
                #         tf.summary.image("val_output", gen_output_validation[:, :, :, ::-1], step=step)
        with strategy.scope():
            # `experimental_run_v2` replicates the provided computation and runs it with
            # the distributed input; the per-replica losses are then summed into one value.
            @tf.function
            def distributed_train_step(dataset_inputs, epoch):
                per_replica_losses = strategy.experimental_run_v2(train_step,
                                                                  args=(dataset_inputs, epoch))
                return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

            # @tf.function
            def distributed_test_step(dataset_inputs, step):
                return strategy.experimental_run_v2(test_step, args=(dataset_inputs, step))
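            # For MirroredStrategy, TensorFlow recommends scaling per-example losses by the
            # global batch size so that summing the per-replica results gives a true mean.
            # A minimal sketch, assuming a per-example loss tensor `per_example_loss` were
            # computed inside train_step:
            #
            #     loss = tf.nn.compute_average_loss(per_example_loss,
            #                                       global_batch_size=BATCH_SIZE)
            #
            # The ReduceOp.SUM above would then produce a correctly averaged global loss.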
        def fit(train_dist_dataset, epochs, test_dist_dataset):
            """Run the distributed training loop and validate every 5 epochs."""
            for epoch in range(epochs):
                # TRAIN LOOP
                print("Epoch: ", epoch)
                total_loss = 0.0
                train_iter = iter(train_dist_dataset)
                total_step = int(cfg.total_tfrecords_for_training / BATCH_SIZE)
                print("The number of total steps for train: {}".format(total_step))
                for num_batches in range(1, total_step + 1):
                    total_loss += distributed_train_step(next(train_iter), epoch)
                    if num_batches % 10 == 0:
                        print("Step {}, loss: {}".format(num_batches, total_loss / num_batches))
                train_loss = total_loss / max(total_step, 1)
                print("Epoch {}, loss: {}".format(epoch + 1, train_loss))
                # saving (checkpoint) the model every epoch
                # checkpoint.save(file_prefix=checkpoint_prefix)

                # Validate on the test dataset every 5 epochs.
                if test_dist_dataset is not None and epoch % 5 == 0:
                    total_test_steps = int(cfg.total_viton_tfrecords_for_testing / BATCH_SIZE)
                    print("The number of total steps for test: {}".format(total_test_steps))
                    test_batch_idx = 0
                    for x in test_dist_dataset:
                        test_batch_idx += 1
                        distributed_test_step(x, tf.convert_to_tensor(test_batch_idx, dtype=tf.int64))

        # summary_writer = tf.summary.create_file_writer(
        #     cfg.log_dir + "fit/" + time.strftime("%Y%m%d-%H%M%S"))
        fit(train_dist_dataset, cfg.EPOCHS, None)
    except Exception as e:
        print(e)


if __name__ == '__main__':
    # learning_rate = get_learning_rate(current_epoch=t_epoch, total_epochs=EPOCHS)
    train()
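
# `get_learning_rate` is imported from warm_start.py, which is not included in this
# gist. A minimal sketch with the same signature, assuming a simple linear decay from
# cfg.max_learning_rate to cfg.min_learning_rate over cfg.EPOCHS (an illustration, not
# necessarily the author's schedule):
#
#     def get_learning_rate(max_learning_rate, min_learning_rate, current_epoch, total_epochs):
#         fraction = min(current_epoch / float(max(total_epochs - 1, 1)), 1.0)
#         return max_learning_rate - (max_learning_rate - min_learning_rate) * fraction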