Last active
March 13, 2021 17:04
-
-
Save Abhishek-Shaw-Kolkata/f49384e13531d37ead9fa4edc5f7c5d0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%env SM_FRAMEWORK=tf.keras | |
import segmentation_models as sm | |
import tensorflow as tf | |
from tensorflow import keras | |
def get_segmentaion_model(name='Uefficientnetb4', BACKBONE='efficientnetb4',
                          ENCODER_WEIGHTS='imagenet', input_shape=(256, 256, 3),
                          learning_rate=0.001):
    '''
    Create a U-Net segmentation model and compile it with the Adam optimizer.

    The model is compiled with ``combined_loss`` as the loss and ``dice_coef``
    as a metric; both names must be defined elsewhere in the notebook.

    Args:
        name : Name of the model. Not used when building the model; kept so
            existing callers that pass it keep working.
        BACKBONE : Name of the backbone network used as the encoder part.
        ENCODER_WEIGHTS : Weights used to initialise the encoder
            (passed to ``sm.Unet`` as ``encoder_weights``).
        input_shape : Input shape passed to ``sm.Unet``
            (default ``(256, 256, 3)``).
        learning_rate : Learning rate for the Adam optimizer (default 0.001).
    Returns:
        The compiled segmentation model object.
    '''
    tf.random.set_seed(100)  # set the global seed for reproducible runs
    keras.backend.clear_session()  # reset Keras state left over from earlier notebook runs
    model_seg = sm.Unet(BACKBONE, input_shape=input_shape,
                        encoder_weights=ENCODER_WEIGHTS)
    # `lr` is a deprecated alias that newer Keras versions reject;
    # `learning_rate` is the supported keyword.
    model_seg.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=combined_loss,
        metrics=[dice_coef],
    )
    return model_seg
def set_log_checkpoint(name): | |
''' | |
helper function to set model checkpoint, tensorboard callback | |
Args: | |
name: the name of the model | |
Returns: | |
model checkpoint callback , tensorboard callback, learningrate schedule callback | |
''' | |
global CHECKPOINT_PATH | |
CHECKPOINT_PATH = "training/" + name + ".ckpt" | |
checkpoint_dir = os.path.dirname(CHECKPOINT_PATH) | |
# Create a callback that saves the model's weights | |
cp_callback = keras.callbacks.ModelCheckpoint(filepath=CHECKPOINT_PATH, | |
save_weights_only=True, | |
verbose=1, | |
monitor='val_loss', | |
save_best_only=True) | |
# Load the TensorBoard notebook extension | |
%load_ext tensorboard | |
# Clear any logs from previous runs | |
! rm -rf ./logs/ | |
# Set up log directory | |
logdir = os.path.join("logs", datetime.now().strftime("%Y%m%d-%H%M%S")) | |
#print(logdir) | |
%tensorboard --logdir $logdir | |
tensorboard_callback = keras.callbacks.TensorBoard(logdir, histogram_freq=1) | |
return cp_callback,tensorboard_callback | |
def train_and_plot_metrics(model, name, cp_callback, tensorboard_callback,
                           train_dataset, test_dataset):
    '''
    Train a model, restore its checkpointed weights, persist it and plot metrics.

    Args:
        model : Compiled Keras model to train.
        name : Name of the model. Used in the output paths
            ``./training/<name>.npy`` (history) and ``./models/<name>`` (model).
        cp_callback : ModelCheckpoint callback object. Weights are reloaded
            from its ``filepath`` after training.
        tensorboard_callback : TensorBoard callback object.
        train_dataset : Batched training data. Must support ``len()``, which
            is used as the number of steps per epoch (presumably a Keras
            ``Sequence`` — confirm).
        test_dataset : Validation data passed to ``model.fit``.
    Returns:
        The ``History`` object returned by ``model.fit``.
    '''
    import numpy as np

    history = model.fit(
        train_dataset,
        use_multiprocessing=True,
        epochs=NB_EPOCH,
        # No `batch_size` here: the dataset already yields batches, and Keras
        # documents that batch_size must not be given for datasets/Sequences.
        steps_per_epoch=len(train_dataset),
        validation_data=test_dataset,
        verbose=1,
        callbacks=[cp_callback, tensorboard_callback],
    )
    # Reload the weights written by the checkpoint callback we were given,
    # rather than relying on a global path that may belong to another callback.
    model.load_weights(cp_callback.filepath)
    # Save the training history (a dict; np.save pickles it).
    np.save('./training/' + name + '.npy', history.history)
    # Save the model together with its architecture.
    model.save('./models/' + name)
    plot_metrics(history)
    return history
# Build, train and evaluate a U-Net with a VGG16 encoder.
name = 'Uvgg16'
model_seg = get_segmentaion_model(name, BACKBONE='vgg16')
cp_callback, tensorboard_callback = set_log_checkpoint(name)
# train_and_plot_metrics requires the training and validation data as its
# last two positional arguments. NOTE(review): assumes `train_dataset` and
# `test_dataset` are built in an earlier notebook cell — confirm.
train_and_plot_metrics(model_seg, name, cp_callback, tensorboard_callback,
                       train_dataset, test_dataset)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment