Skip to content

Instantly share code, notes, and snippets.

@Abhishek-Shaw-Kolkata
Last active March 13, 2021 17:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Abhishek-Shaw-Kolkata/f49384e13531d37ead9fa4edc5f7c5d0 to your computer and use it in GitHub Desktop.
Save Abhishek-Shaw-Kolkata/f49384e13531d37ead9fa4edc5f7c5d0 to your computer and use it in GitHub Desktop.
%env SM_FRAMEWORK=tf.keras
import segmentation_models as sm
import tensorflow as tf
from tensorflow import keras
def get_segmentaion_model(name='Uefficientnetb4', BACKBONE='efficientnetb4', ENCODER_WEIGHTS='imagenet',
                          input_shape=(256, 256, 3), learning_rate=0.001):
    """Create a U-Net segmentation model and compile it.

    The model is compiled with the Adam optimizer, ``combined_loss`` as the
    loss and ``dice_coef`` as the metric (both are expected to be defined
    elsewhere in the notebook).

    Args:
        name: Name of the model. Not used by the function body; kept for
            interface compatibility.
        BACKBONE: Name of the ``segmentation_models`` backbone used as the
            encoder part of the U-Net.
        ENCODER_WEIGHTS: Weights used to initialise the encoder, passed to
            ``sm.Unet`` as ``encoder_weights`` (e.g. ``'imagenet'``).
        input_shape: Input shape passed to ``sm.Unet``. Defaults to the
            previously hard-coded ``(256, 256, 3)``.
        learning_rate: Learning rate for the Adam optimizer. Defaults to the
            previously hard-coded ``0.001``.

    Returns:
        The compiled segmentation model.
    """
    tf.random.set_seed(100)  # Set the global seed for reproducible initialisation.
    keras.backend.clear_session()  # Reset Keras state left over from earlier models.
    model_seg = sm.Unet(BACKBONE, input_shape=input_shape, encoder_weights=ENCODER_WEIGHTS)
    # `learning_rate` replaces the deprecated `lr` alias, which newer Keras
    # versions warn about or reject.
    model_seg.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=combined_loss,
        metrics=[dice_coef],
    )
    return model_seg
def set_log_checkpoint(name):
    """Create the model-checkpoint and TensorBoard callbacks for a run.

    Side effects:
      * sets the module-level ``CHECKPOINT_PATH`` to ``training/<name>.ckpt``;
      * creates the checkpoint directory if it does not exist;
      * deletes the ``./logs`` directory left over from previous runs;
      * when running under IPython/Jupyter, loads the TensorBoard notebook
        extension and launches TensorBoard on the new log directory
        (the same as ``%load_ext tensorboard`` / ``%tensorboard --logdir``).

    Args:
        name: The name of the model; used to build the checkpoint file name.

    Returns:
        Tuple ``(cp_callback, tensorboard_callback)``: a ``ModelCheckpoint``
        that keeps only the weights with the best ``val_loss``, and a
        ``TensorBoard`` callback logging to a timestamped directory.
    """
    import os
    import shutil
    from datetime import datetime

    global CHECKPOINT_PATH
    CHECKPOINT_PATH = "training/" + name + ".ckpt"
    checkpoint_dir = os.path.dirname(CHECKPOINT_PATH)
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Create a callback that saves the model's weights whenever val_loss improves.
    cp_callback = keras.callbacks.ModelCheckpoint(filepath=CHECKPOINT_PATH,
                                                  save_weights_only=True,
                                                  verbose=1,
                                                  monitor='val_loss',
                                                  save_best_only=True)

    # Notebook-only TensorBoard integration. These calls are exactly what the
    # `%...` magics expand to, but they also keep the function valid plain Python.
    try:
        from IPython import get_ipython
    except ImportError:  # not running under IPython: no notebook magics
        shell = None
    else:
        shell = get_ipython()

    # Load the TensorBoard notebook extension.
    if shell is not None:
        shell.run_line_magic("load_ext", "tensorboard")
    # Clear any logs from previous runs (replaces the shell command `rm -rf ./logs/`).
    shutil.rmtree("./logs", ignore_errors=True)
    # Set up a fresh, timestamped log directory.
    logdir = os.path.join("logs", datetime.now().strftime("%Y%m%d-%H%M%S"))
    if shell is not None:
        shell.run_line_magic("tensorboard", "--logdir " + logdir)
    tensorboard_callback = keras.callbacks.TensorBoard(logdir, histogram_freq=1)
    return cp_callback, tensorboard_callback
def train_and_plot_metrics(model, name, cp_callback, tensorboard_callback, train_dataset, test_dataset):
    """Train a model, restore its best checkpoint, save artefacts and plot metrics.

    Args:
        model: Compiled Keras model to train.
        name: Model name; used for the saved history file
            (``training/<name>.npy``) and the saved model (``models/<name>``).
        cp_callback: ``ModelCheckpoint`` callback. After training, the weights
            are reloaded from its ``filepath``.
        tensorboard_callback: TensorBoard callback.
        train_dataset: Training data. It is passed to ``fit`` without ``y``
            and sized with ``len()``, so it is presumably a batched dataset or
            ``keras.utils.Sequence`` (TODO confirm).
        test_dataset: Validation data.

    Returns:
        None
    """
    import os
    import numpy as np

    history = model.fit(
        train_dataset,
        use_multiprocessing=True,
        epochs=NB_EPOCH,
        # No batch_size: the input already yields batches, and the Keras
        # `fit` docs say not to pass batch_size for datasets/Sequences.
        steps_per_epoch=len(train_dataset),
        validation_data=test_dataset,
        verbose=1,
        callbacks=[cp_callback, tensorboard_callback],
    )
    # Load the best saved weights from the checkpoint callback actually used,
    # rather than relying on a module-level path.
    model.load_weights(cp_callback.filepath)
    # Save the training history. np.save pickles the dict into a 0-d object
    # array; read it back with np.load(path, allow_pickle=True).item().
    os.makedirs("training", exist_ok=True)
    np.save(os.path.join("training", name + ".npy"), history.history)
    # Save the full model (architecture + weights).
    model.save(os.path.join("models", name))
    plot_metrics(history)
# Build, configure and train a U-Net with a VGG16 encoder.
name = 'Uvgg16'
model_seg = get_segmentaion_model(name, BACKBONE='vgg16')
cp_callback, tensorboard_callback = set_log_checkpoint(name)
# train_and_plot_metrics takes the training and validation datasets as required
# arguments; omitting them raises TypeError. train_dataset / test_dataset are
# assumed to be defined in an earlier notebook cell — TODO confirm.
train_and_plot_metrics(model_seg, name, cp_callback, tensorboard_callback,
                       train_dataset, test_dataset)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment