Skip to content

Instantly share code, notes, and snippets.

@Tob-iee
Created September 13, 2023 14:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Tob-iee/20755d6693a0ee6a067b6cf40d8f1e8b to your computer and use it in GitHub Desktop.
Save Tob-iee/20755d6693a0ee6a067b6cf40d8f1e8b to your computer and use it in GitHub Desktop.
Custom training callback that logs training activity from the Trainer API to a Picsellia experiment
class CustomPicselliaCallback(TrainerCallback):
    """Trainer callback that forwards training metrics to a Picsellia experiment.

    When training ends, every entry of ``state.log_history`` that carries a
    ``'loss'`` key has its loss (and learning rate, when present) sent to the
    experiment as ``LogType.LINE`` logs.
    """

    # Number of extra attempts made after a failed ``experiment.log`` call.
    LOG_RETRIES = 3

    def __init__(self, experiment: Experiment):
        """Store the experiment that will receive the logs."""
        self.experiment = experiment

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control, **kwargs):
        """
        Event called at the beginning of training.
        """
        print("Starting training")

    def on_train_end(self, args: TrainingArguments, state: TrainerState, control, **kwargs):
        """
        Event called at the end of training.

        Prints the full log history, then sends the loss and learning rate of
        each history entry that contains a ``'loss'`` key to the experiment.
        Failures to send are reported but never raised, so they cannot abort
        the end-of-training hook.
        """
        # Keep track of train loss.
        print("state.log_history:", state.log_history)

        # Loop through each log history entry.
        for log_history in state.log_history:
            if 'loss' not in log_history:
                # Entries without a 'loss' key are not training-loss records.
                continue

            # Deal with training loss.
            loss = log_history['loss']
            print('train_loss:', loss)
            sent = self._log_metric("loss_training_hist", loss, self.LOG_RETRIES)

            # An entry may lack 'learning_rate'; skip it rather than raise KeyError.
            learning_rate_decay = log_history.get('learning_rate')
            if learning_rate_decay is not None:
                print('train_lr-decay:', learning_rate_decay)
                sent = self._log_metric(
                    "lr-decay_hist", learning_rate_decay, self.LOG_RETRIES
                ) and sent

            if not sent:
                print("can't send log")

    def _log_metric(self, name: str, value: float, retry: int = 0) -> bool:
        """Send one value to the experiment, retrying on failure.

        Args:
            name: Name of the log in the experiment.
            value: Value to record.
            retry: Number of additional attempts after the first failure.
                Negative values are treated as zero.

        Returns:
            True if a call to ``experiment.log`` succeeded, False if every
            attempt failed.
        """
        # NOTE(review): every call passes replace=True; whether successive
        # values under the same name accumulate or overwrite each other depends
        # on Picsellia's Experiment.log -- confirm against the SDK docs.
        attempts = max(retry, 0) + 1
        for attempt in range(attempts):
            try:
                self.experiment.log(name=name, data=value, type=LogType.LINE, replace=True)
            except Exception:
                # Best-effort reporting: a logging failure must not break training.
                logging.exception(f"couldn't log {name}")
                if attempt + 1 < attempts:
                    logging.info(f"retrying log {name}")
            else:
                return True
        return False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment