Skip to content

Instantly share code, notes, and snippets.

@saitejamalyala
Last active August 6, 2021 00:01
Show Gist options
  • Save saitejamalyala/075d5714b8fe0e713fa4f377d7fd7cbd to your computer and use it in GitHub Desktop.
Save saitejamalyala/075d5714b8fe0e713fa4f377d7fd7cbd to your computer and use it in GitHub Desktop.
Custom wandb callback for Keras model training (`.fit`) that predicts on sample test data and logs a plot of one sample at the beginning of every other epoch.
from typing import Dict, List, Union

import matplotlib.pyplot as plt
import tensorflow as tf
import wandb
from numpy import ndarray
from wandb.keras import WandbCallback
class cd_wandb_custom(WandbCallback):
    """``WandbCallback`` that also logs a plot of one test-sample prediction.

    At the beginning of epochs 1, 3, 5, ... (Keras epoch indices are 0-based)
    the callback:

    * runs ``self.model.predict`` on ``ds_test``;
    * stores the result in ``np_test_dataset["predictions"]`` (this mutates
      the dict passed in by the caller);
    * plots the scene at ``test_index``;
    * logs the plot to W&B under ``sample_img_{epoch - 1}``.

    Every other constructor argument is forwarded unchanged to
    ``WandbCallback``.
    """

    def __init__(
        self,
        # newly added
        ds_test,
        np_test_dataset: Dict[str, Union[ndarray, List]],
        test_index: int = 15,
        # forwarded to WandbCallback
        monitor="val_loss",
        verbose=0,
        mode="auto",
        save_weights_only=False,
        log_weights=False,
        log_gradients=False,
        save_model=True,
        training_data=None,
        validation_data=None,
        labels=None,
        data_type=None,
        predictions=36,
        generator=None,
        input_type=None,
        output_type=None,
        log_evaluation=False,
        validation_steps=None,
        class_colors=None,
        log_batch_frequency=None,
        log_best_prefix="best_",
        save_graph=True,
        validation_indexes=None,
        validation_row_processor=None,
        prediction_row_processor=None,
        infer_missing_processors=True,
    ):
        """Store the test data and forward the remaining options to ``WandbCallback``.

        :param ds_test: input passed to ``self.model.predict`` each logging
            epoch (presumably a ``tf.data.Dataset`` whose sample order matches
            ``np_test_dataset`` -- confirm with the caller)
        :param np_test_dataset: per-sample data indexed by sample position.
            It must contain the keys ``grid_map``, ``grid_org_res``,
            ``left_bnd``, ``right_bnd``, ``car_odo``, ``init_path``,
            ``opt_path`` and ``file_details``. The ``predictions`` key is
            written by this callback.
        :param test_index: index of the sample to plot
        :param labels: forwarded to ``WandbCallback``; ``None`` means ``[]``
        """
        super().__init__(
            monitor=monitor,
            verbose=verbose,
            mode=mode,
            save_weights_only=save_weights_only,
            log_weights=log_weights,
            log_gradients=log_gradients,
            save_model=save_model,
            training_data=training_data,
            validation_data=validation_data,
            # A fresh list per instance instead of a shared mutable default.
            labels=[] if labels is None else labels,
            data_type=data_type,
            predictions=predictions,
            generator=generator,
            input_type=input_type,
            output_type=output_type,
            log_evaluation=log_evaluation,
            validation_steps=validation_steps,
            class_colors=class_colors,
            log_batch_frequency=log_batch_frequency,
            log_best_prefix=log_best_prefix,
            save_graph=save_graph,
            validation_indexes=validation_indexes,
            validation_row_processor=validation_row_processor,
            prediction_row_processor=prediction_row_processor,
            infer_missing_processors=infer_missing_processors,
        )
        self.np_ds_test = np_test_dataset
        self.test_idx = test_index
        self.ds_test = ds_test

    def __get_index_data(self, np_ds_test, test_idx):
        """Collect every field needed to plot sample ``test_idx`` into one dict.

        :param np_ds_test: dataset dict (see ``__init__``); it must already
            contain ``predictions``
        :param test_idx: sample index
        :return: dict with the per-sample fields, ``testidx`` and ``predictions``
        """
        scene_keys = (
            "grid_map",
            "grid_org_res",
            "left_bnd",
            "right_bnd",
            "car_odo",
            "init_path",
            "opt_path",
            "file_details",
        )
        test_data = {key: np_ds_test[key][test_idx] for key in scene_keys}
        test_data["testidx"] = test_idx
        test_data["predictions"] = np_ds_test["predictions"][test_idx]
        return test_data

    def __plot_scene(self, features):
        """Draw the occupancy grid plus paths and bounds for one sample.

        Each point is shifted by the grid origin and divided by the grid
        resolution, so it lines up with the pixels of ``grid_map`` drawn by
        ``imshow``.

        :param features: dict produced by ``__get_index_data``
        :return: ``(resolution, figure)``. The caller owns the figure and must
            close it.
        """
        grid_org = features["grid_org_res"]  # [x, y, resolution]
        res = grid_org[2]
        left_bnd = features["left_bnd"]
        right_bnd = features["right_bnd"]
        init_path = features["init_path"]
        opt_path = features["opt_path"]
        predict_path = features["predictions"]
        car_odo = features["car_odo"]

        def to_grid(xs, ys):
            # Convert coordinates into grid-cell units relative to the origin.
            return (xs - grid_org[0]) / res, (ys - grid_org[1]) / res

        fig, ax = plt.subplots(figsize=(10, 10), dpi=200)
        ax.plot(
            *to_grid(left_bnd[:, 0], left_bnd[:, 1]),
            "-.",
            color="magenta",
            markersize=0.5,
            linewidth=0.5,
            label="Left bound",
        )
        ax.plot(
            *to_grid(init_path[:, 0], init_path[:, 1]),
            "o-",
            color="lawngreen",
            markersize=1,
            linewidth=1,
            label="gt_init_path",
        )
        ax.plot(
            *to_grid(opt_path[:, 0], opt_path[:, 1]),
            "--",
            color="yellow",
            markersize=1,
            linewidth=1,
            label="gt_opt_path",
        )
        ax.plot(
            *to_grid(predict_path[:, 0], predict_path[:, 1]),
            "--",
            color="orange",
            markersize=1,
            linewidth=1,
            label="predicted_path",
        )
        ax.plot(
            *to_grid(right_bnd[:, 0], right_bnd[:, 1]),
            "-.",
            color="magenta",
            markersize=0.5,
            linewidth=0.5,
            label="right bound",
        )
        # Color is given only once; "r*" plus color="red" makes matplotlib warn.
        ax.plot(
            *to_grid(car_odo[0], car_odo[1]),
            "*",
            color="red",
            markersize=8,
            label="car_centre",
        )
        ax.legend(loc="lower left")
        ax.imshow(features["grid_map"], origin="lower")
        ax.set_title(f"{features['file_details']}\nTest Index: {features['testidx']}")
        return res, fig

    def on_epoch_begin(self, epoch, logs=None):
        """Log the sample plot on every other epoch, then defer to ``WandbCallback``.

        At the beginning of epoch ``epoch``, the model holds the weights
        reached after epoch ``epoch - 1``, hence the ``sample_img_{epoch - 1}``
        key.
        """
        if (epoch - 1) % 2 == 0:
            # Predict and plot only on epochs that are actually logged.
            # Predicting the full test set is expensive, and unlogged figures
            # used to be left open.
            self.np_ds_test["predictions"] = self.model.predict(self.ds_test)
            sample_data = self.__get_index_data(
                np_ds_test=self.np_ds_test, test_idx=self.test_idx
            )
            _, sample_fig = self.__plot_scene(features=sample_data)
            try:
                # commit=False: do not advance the W&B step on our own. The
                # image is sent with the next committed wandb.log call
                # (presumably WandbCallback's epoch-end metrics).
                wandb.log(
                    {f"sample_img_{epoch - 1}": wandb.Image(sample_fig)},
                    commit=False,
                )
            finally:
                plt.close(sample_fig)
        return super().on_epoch_begin(epoch, logs=logs)
def euclidean_distance_loss(y_true, y_pred):
    """Euclidean distance between ``y_true`` and ``y_pred`` along the last axis.

    https://en.wikipedia.org/wiki/Euclidean_distance

    :param y_true: TensorFlow tensor of targets
    :param y_pred: TensorFlow tensor of the same shape as ``y_true``
    :return: tensor of distances with the last axis reduced away; Keras
        reduces it to the scalar training loss
    """
    # Uses tf ops directly; the former ``K`` (keras backend) alias was never
    # imported. For a sum of squares this matches K.sqrt/K.sum/K.square.
    # NOTE(review): d sqrt/dx is infinite at 0, so an exact match between
    # prediction and target gives NaN gradients. Add a small epsilon inside
    # the sqrt if that can happen with this data.
    return tf.sqrt(tf.reduce_sum(tf.square(y_pred - y_true), axis=-1))
def endpoint_loss(y_true, y_pred):
    """Euclidean distance between the final predicted point and the final target point.

    Per-sample predictions are arrays of points (the plotting code reads
    ``predictions[i][:, 0]`` and ``[:, 1]``). Inside a Keras loss the tensors
    are therefore batched, presumably as ``(batch, n_points, coords)`` --
    confirm against the dataset.

    ``[..., -1, :]`` selects the last point of every sample. Plain
    ``[-1, :]`` would select the last *sample in the batch* and sum over its
    whole path. For rank-2 inputs both forms select the same row.

    :param y_true: target tensor, rank >= 2
    :param y_pred: prediction tensor with the same shape as ``y_true``
    :return: per-sample endpoint distances; a scalar for rank-2 inputs
    """
    diff = y_pred[..., -1, :] - y_true[..., -1, :]
    return tf.sqrt(tf.reduce_sum(tf.square(diff), axis=-1))
# Illustrative training setup. ds_train, ds_valid, ds_test, np_ds_test and
# params are not created in this file and must be defined by the caller.
model = tf.keras.models.Sequential(
    [
        tf.keras.Input(name="input_layer", shape=(10,)),
        tf.keras.layers.Dense(50, activation="relu"),
        tf.keras.layers.Dense(1, activation="sigmoid", name="output_layer"),
    ]
)


def combined_loss(y_true, y_pred):
    """Weighted sum of the path loss (weight 1.0) and the endpoint loss (weight 0.25).

    Keras maps a *list* of losses and ``loss_weights`` one-to-one onto model
    outputs. This model has a single output, so the two losses are reduced
    and combined here instead of being passed as a list.
    """
    return 1.0 * tf.reduce_mean(euclidean_distance_loss(y_true, y_pred)) + 0.25 * tf.reduce_mean(
        endpoint_loss(y_true, y_pred)
    )


opt = "adam"
model.compile(optimizer=opt, loss=combined_loss, metrics=["accuracy"])

# Learning rate scheduler
cb_reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss", factor=0.2, patience=4, min_lr=0.0001
)
history = model.fit(
    ds_train,
    epochs=params.get("epochs"),
    validation_data=ds_valid,
    callbacks=[
        cb_reduce_lr,
        cd_wandb_custom(ds_test=ds_test, np_test_dataset=np_ds_test, test_index=15),
    ],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment