Last active
August 6, 2021 00:01
-
-
Save saitejamalyala/075d5714b8fe0e713fa4f377d7fd7cbd to your computer and use it in GitHub Desktop.
Custom wandb callback for Keras model training (`.fit`) that predicts on sample test data and logs a plot of the prediction at the beginning of every other epoch.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Dict, List, Union

import tensorflow as tf
import wandb
from numpy import ndarray
from wandb.keras import WandbCallback
class cd_wandb_custom(WandbCallback):
    """``WandbCallback`` that additionally logs a plot of one test sample's prediction.

    At the beginning of each epoch for which
    ``(epoch - 1) % sample_log_frequency == 0`` (with the default of 2 that is
    epoch 1, 3, 5, ...), the current model predicts on ``ds_test``; the sample
    at ``test_index`` is drawn over its grid map together with its left/right
    bounds, ground-truth initial/optimised paths, predicted path and car
    position, and the figure is logged to wandb under the key
    ``sample_img_<epoch - 1>``. On all other epochs nothing is predicted or
    plotted.

    Args:
        ds_test: Input passed unchanged to ``self.model.predict``.
        np_test_dataset: Mapping of per-sample arrays indexed by sample
            position, containing every key in ``_SAMPLE_KEYS`` except
            "predictions" (which this callback supplies). Presumably aligned
            index-for-index with ``ds_test`` -- TODO confirm. The mapping is
            never modified.
        test_index: Position of the sample to plot.
        sample_log_frequency: Log a sample plot every this many epochs;
            must be >= 1.

    Every other argument is forwarded unchanged to ``WandbCallback``, except
    that ``labels=None`` (the default) is forwarded as ``[]``.

    Raises:
        ValueError: if ``sample_log_frequency`` is smaller than 1.
    """

    # Per-sample entries copied out of the dataset mapping for plotting.
    _SAMPLE_KEYS = (
        "grid_map",
        "grid_org_res",
        "left_bnd",
        "right_bnd",
        "car_odo",
        "init_path",
        "opt_path",
        "file_details",
        "predictions",
    )

    # (dataset key, matplotlib fmt, colour, marker size == line width).
    # Drawing order must match the first five legend labels in __plot_scene.
    _PATH_STYLES = (
        ("left_bnd", "-.", "magenta", 0.5),
        ("init_path", "o-", "lawngreen", 1),
        ("opt_path", "--", "yellow", 1),
        ("predictions", "--", "orange", 1),
        ("right_bnd", "-.", "magenta", 0.5),
    )

    def __init__(
        self,
        # newly added
        ds_test,
        np_test_dataset: Dict[str, Union[ndarray, List]],
        test_index: int = 15,
        # forwarded to WandbCallback
        monitor="val_loss",
        verbose=0,
        mode="auto",
        save_weights_only=False,
        log_weights=False,
        log_gradients=False,
        save_model=True,
        training_data=None,
        validation_data=None,
        labels=None,
        data_type=None,
        predictions=36,
        generator=None,
        input_type=None,
        output_type=None,
        log_evaluation=False,
        validation_steps=None,
        class_colors=None,
        log_batch_frequency=None,
        log_best_prefix="best_",
        save_graph=True,
        validation_indexes=None,
        validation_row_processor=None,
        prediction_row_processor=None,
        infer_missing_processors=True,
        # newly added, kept last so positional callers are unaffected
        sample_log_frequency: int = 2,
    ):
        if sample_log_frequency < 1:
            raise ValueError(
                f"sample_log_frequency must be >= 1, got {sample_log_frequency}"
            )
        super().__init__(
            monitor=monitor,
            verbose=verbose,
            mode=mode,
            save_weights_only=save_weights_only,
            log_weights=log_weights,
            log_gradients=log_gradients,
            save_model=save_model,
            training_data=training_data,
            validation_data=validation_data,
            # A fresh list per instance instead of one shared mutable default.
            labels=[] if labels is None else labels,
            data_type=data_type,
            predictions=predictions,
            generator=generator,
            input_type=input_type,
            output_type=output_type,
            log_evaluation=log_evaluation,
            validation_steps=validation_steps,
            class_colors=class_colors,
            log_batch_frequency=log_batch_frequency,
            log_best_prefix=log_best_prefix,
            save_graph=save_graph,
            validation_indexes=validation_indexes,
            validation_row_processor=validation_row_processor,
            prediction_row_processor=prediction_row_processor,
            infer_missing_processors=infer_missing_processors,
        )
        self.np_ds_test = np_test_dataset
        self.test_idx = test_index
        self.ds_test = ds_test
        self.sample_log_frequency = sample_log_frequency

    def __get_index_data(self, np_ds_test, test_idx):
        """Return a dict with entry ``test_idx`` of every ``_SAMPLE_KEYS`` array plus ``"testidx"``."""
        test_data = {key: np_ds_test[key][test_idx] for key in self._SAMPLE_KEYS}
        test_data["testidx"] = test_idx
        return test_data

    def __plot_scene(self, features):
        """Draw one sample on a new current pyplot figure.

        Returns ``(resolution, plt)``; the caller must close the figure.
        """
        # NOTE(review): `plt` (matplotlib.pyplot) is not imported in this file;
        # it must be provided at module level elsewhere -- confirm.
        grid_org = features["grid_org_res"]  # [x, y, resolution]
        res = grid_org[2]

        def to_grid(points):
            # Subtract the grid origin and divide by the resolution, giving grid-cell
            # units (presumably the pixel axes of imshow(grid_map) below).
            return (
                (points[:, 0] - grid_org[0]) / res,
                (points[:, 1] - grid_org[1]) / res,
            )

        plt.figure(figsize=(10, 10), dpi=200)
        for key, fmt, colour, size in self._PATH_STYLES:
            plt.plot(
                *to_grid(features[key]),
                fmt,
                color=colour,
                markersize=size,
                linewidth=size,
            )
        car_odo = features["car_odo"]
        # Marker only in the fmt string; the colour comes solely from `color=`
        # (the previous "r*" specified it twice).
        plt.plot(
            (car_odo[0] - grid_org[0]) / res,
            (car_odo[1] - grid_org[1]) / res,
            "*",
            color="red",
            markersize=8,
        )
        plt.legend(
            [
                "Left bound",
                "gt_init_path",
                "gt_opt_path",
                "predicted_path",
                "right bound",
                "car_centre",
            ],
            loc="lower left",
        )
        plt.imshow(features["grid_map"], origin="lower")
        plt.title(f"{features['file_details']}\nTest Index: {features['testidx']}")
        return res, plt

    def __log_sample_prediction(self, epoch):
        """Predict on ``ds_test``, plot sample ``test_idx`` and log it to wandb."""
        np_predictions = self.model.predict(self.ds_test)
        # Attach the predictions to a shallow copy so the caller's dict is not mutated.
        dataset = {**self.np_ds_test, "predictions": np_predictions}
        sample_data = self.__get_index_data(np_ds_test=dataset, test_idx=self.test_idx)
        _, sample_fig = self.__plot_scene(features=sample_data)
        try:
            wandb.log({f"sample_img_{epoch - 1}": sample_fig})
        finally:
            # Always release the figure so figures do not pile up across epochs.
            sample_fig.close()

    def on_epoch_begin(self, epoch, logs=None):
        """Log a sample plot on the scheduled epochs, then defer to ``WandbCallback``.

        The key uses ``epoch - 1``, i.e. the plot shows the model state after
        that epoch's training.
        """
        # Only predict/plot when the result will actually be logged.
        if (epoch - 1) % self.sample_log_frequency == 0:
            self.__log_sample_prediction(epoch)
        return super().on_epoch_begin(epoch, logs=logs)
def euclidean_distance_loss(y_true, y_pred):
    """
    Euclidean distance loss
    https://en.wikipedia.org/wiki/Euclidean_distance

    Computes ``sqrt(sum((y_pred - y_true) ** 2, axis=-1))``: the L2 distance
    along the last axis, one value per remaining index.

    Uses ``tf`` ops directly (the previous ``K`` backend alias was never
    imported in this file).

    :param y_true: TensorFlow tensor
    :param y_pred: TensorFlow tensor of the same shape as y_true
    :return: tensor of distances with the last axis reduced away
    """
    return tf.sqrt(tf.reduce_sum(tf.square(y_pred - y_true), axis=-1))
def endpoint_loss(y_true, y_pred):
    """
    Euclidean distance between the ``[-1, :]`` slices of ``y_pred`` and ``y_true``.

    ``[-1, :]`` takes the last index along the FIRST axis, and the sum runs
    over every element of that slice, so the result is a single scalar.

    NOTE(review): if the first axis is the batch axis, this compares only the
    last sample of each batch rather than the last point of every path --
    confirm the intended indexing (e.g. ``[:, -1, :]`` for inputs shaped
    ``(batch, points, 2)``).

    :param y_true: TensorFlow tensor
    :param y_pred: TensorFlow tensor of the same shape as y_true
    :return: scalar tensor
    """
    diff = y_pred[-1, :] - y_true[-1, :]
    return tf.sqrt(tf.reduce_sum(tf.square(diff)))
# Example training run. `ds_train`, `ds_valid`, `ds_test`, `np_ds_test` and
# `params` are not defined in this file; they are presumably built earlier by
# the data pipeline -- TODO confirm before running.
# NOTE(review): this toy model (10 inputs -> 1 sigmoid unit) is a placeholder.
# The plotting callback indexes each sample's prediction as `[:, 0]` / `[:, 1]`,
# so the real model must output a 2-D path per sample.
model = tf.keras.models.Sequential(
    [
        tf.keras.Input(name="input_layer", shape=(10,)),
        tf.keras.layers.Dense(50, activation="relu"),
        tf.keras.layers.Dense(1, activation="sigmoid", name="output_layer"),
    ]
)
opt = "adam"


def combined_path_loss(y_true, y_pred):
    """Weighted sum ``1.0 * euclidean + 0.25 * endpoint`` applied to the single model output.

    Keras pairs a *list* of losses (and ``loss_weights``) with model outputs
    one-to-one; this model has only one output, so passing two list entries
    would not apply both weighted terms. Combining them in one function makes
    the intended weighting explicit.
    """
    return 1.0 * euclidean_distance_loss(y_true, y_pred) + 0.25 * endpoint_loss(
        y_true, y_pred
    )


model.compile(optimizer=opt, loss=combined_path_loss, metrics=["accuracy"])

# Learning rate scheduler
cb_reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss", factor=0.2, patience=4, min_lr=0.0001
)
history = model.fit(
    ds_train,
    epochs=params.get("epochs"),
    validation_data=ds_valid,
    callbacks=[
        cb_reduce_lr,
        cd_wandb_custom(ds_test=ds_test, np_test_dataset=np_ds_test, test_index=15),
    ],
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment