Created
September 13, 2023 14:12
-
-
Save Tob-iee/20755d6693a0ee6a067b6cf40d8f1e8b to your computer and use it in GitHub Desktop.
Custom training callback that logs training activity to a Picsellia experiment through the Trainer API's callback mechanism
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CustomPicselliaCallback(TrainerCallback):
    """Trainer callback that forwards training metrics to a Picsellia experiment.

    When training ends, every entry of ``state.log_history`` that carries a
    ``loss`` key is sent to the experiment as a ``LogType.LINE`` log, along
    with the entry's ``learning_rate`` when one is present.
    """

    # Extra attempts made after a failed log call before giving up.
    _MAX_LOG_RETRIES = 2

    def __init__(self, experiment: Experiment):
        """Store the experiment object that will receive the logs."""
        self.experiment = experiment

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control, **kwargs):
        """Event called at the beginning of training; prints a start banner."""
        print("Starting training")

    def on_train_end(self, args: TrainingArguments, state: TrainerState, control, **kwargs):
        """Event called at the end of training.

        Walks ``state.log_history`` and sends the training loss and learning
        rate of each training entry to the experiment. Logging is best-effort:
        a failed send is reported and the loop moves on to the next value.
        """
        print("state.log_history:", state.log_history)

        for log_history in state.log_history:
            # Only entries with a 'loss' key are handled; all others are skipped.
            if "loss" not in log_history:
                continue

            loss = log_history["loss"]
            # .get(): an entry without a learning rate must not abort the hook.
            learning_rate_decay = log_history.get("learning_rate")
            print('train_loss:', loss)
            print('train_lr-decay:', learning_rate_decay)

            metrics = [("loss_training_hist", loss)]
            if learning_rate_decay is not None:
                metrics.append(("lr-decay_hist", learning_rate_decay))

            for metric_name, metric_value in metrics:
                if not self._log_metric(metric_name, metric_value, self._MAX_LOG_RETRIES):
                    print("can't send log")

    def _log_metric(self, name: str, value: float, retry: int) -> bool:
        """Send one value to the experiment as a LINE log, retrying on failure.

        Args:
            name: Name of the log in the experiment.
            value: Value to record.
            retry: Number of additional attempts after the first failure.
                Non-integer or negative values are treated as 0.

        Returns:
            True if the value was logged, False if every attempt failed.
        """
        # Earlier call sites passed a LogType here instead of a count; treat
        # anything that is not a plain int as "no retries" rather than crashing.
        retries = retry if isinstance(retry, int) and not isinstance(retry, bool) else 0
        retries = max(retries, 0)

        for remaining in range(retries, -1, -1):
            try:
                # NOTE(review): replace=True is kept from the original; confirm
                # against the Picsellia SDK whether successive values for the
                # same name overwrite or extend the line.
                self.experiment.log(name=name, data=value, type=LogType.LINE, replace=True)
                return True
            except Exception:
                logging.exception("couldn't log %s", name)
                if remaining > 0:
                    logging.info("retrying log %s", name)
        return False
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment