Logging other numbers in TensorBoard for MMF.
# in mmf/mmf/trainers/mmf_trainer.py
# add the import below
from user import all_user_callbacks

# and add two lines in configure_callbacks (see below)
@registry.register_trainer("mmf")
class MMFTrainer(
    TrainerCallbackHookMixin,
    TrainerTrainingLoopMixin,
    TrainerDeviceMixin,
    TrainerEvaluationLoopMixin,
    TrainerProfilingMixin,
    BaseTrainer,
):
    def __init__(self, config: DictConfig):
        super().__init__(config)

    def configure_callbacks(self):
        ...
        # user callbacks  # added by @Jie
        for callback_cls in all_user_callbacks:
            self.callbacks.append(callback_cls(self.config, self))
# user.py
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import math

from mmf.trainers.callbacks.base import Callback
from mmf.utils.general import print_model_parameters

logger = logging.getLogger(__name__)

class VseInfCallback(Callback):
    """A user-customized callback class for VSE Inf training."""

    def __init__(self, config, trainer):
        """
        Attr:
            config (mmf_typings.DictConfig): Config for the callback
            trainer (Type[BaseTrainer]): Trainer object
        """
        super().__init__(config, trainer)
        if hasattr(config.training, "vse_inf_training"):
            vse_cfg = config.training.vse_inf_training
            self.enable = vse_cfg.enable
            self.warmup_steps = vse_cfg.warmup_steps
            # freeze_img_backbone and max_violation only take effect when enable is True
            self.freeze_img_backbone = vse_cfg.freeze_img_backbone  # default True
            self.freeze_blocks_after_warmup = vse_cfg.freeze_blocks_after_warmup  # default 0
            self.max_violation = vse_cfg.max_violation  # default False
            logger.info(f"VseInfCallback: {repr(vse_cfg)}")
        else:
            self.enable = False
    def on_update_start(self, **kwargs):
        if self.enable:
            self.apply_vse_training_scheme()

    def apply_vse_training_scheme(self):
        model = self.trainer.model
        if hasattr(model, "module"):  # DDP
            model = model.module
        current_iteration = self.trainer.current_iteration
        verbose = current_iteration in [0, self.warmup_steps]
        # loss
        model.set_max_violation(True, verbose=verbose)
        if current_iteration < self.warmup_steps:
            # reset for warmup phase. self.max_violation is False by default
            model.set_max_violation(self.max_violation, verbose=verbose)
        # freeze/unfreeze parameters
        if current_iteration < self.warmup_steps and self.freeze_img_backbone:
            model.freeze_img_backbone(verbose)  # freeze all
        else:
            # 0 == unfreeze all
            model.unfreeze_img_backbone(self.freeze_blocks_after_warmup, verbose)
        if verbose:
            print_model_parameters(model)

class TensorBoardLoggingCallback(Callback):
    """A customized callback class for logging other parameters."""

    def __init__(self, config, trainer):
        """
        Attr:
            config (mmf_typings.DictConfig): Config for the callback
            trainer (Type[BaseTrainer]): Trainer object
        """
        super().__init__(config, trainer)
        logger.info(f"{str(self)}: {repr(config.training.vse_inf_training)}")

    def on_update_end(self, **kwargs):
        trainer = self.trainer
        num_updates = trainer.num_updates
        tb_writer = trainer.logistics_callback.tb_writer
        # log any other numbers here
        if num_updates % trainer.logistics_callback.log_interval == 0:
            # learning rate
            lr = trainer.optimizer.param_groups[0]["lr"]
            tb_writer.add_scalar("train/lr", lr, num_updates)
            # trainable temperature
            _model = trainer.model
            _model = _model.module if hasattr(_model, "module") else _model
            if hasattr(_model, "learned_temperature"):
                t = float(_model.learned_temperature.cpu())
                t = 1.0 / math.exp(-t)  # i.e., exp(t), as in CLIP's implementation
                tb_writer.add_scalar("train/temperature", t, num_updates)


all_user_callbacks = [VseInfCallback, TensorBoardLoggingCallback]
jayleicn commented Jul 19, 2021

Step 1: Create a new file user.py and write a new callback class TensorBoardLoggingCallback. This class has access to the trainer (see Step 4), and thus to TensorBoard via self.trainer.logistics_callback.tb_writer, so we can easily log numbers with it. For example, in my use case, I logged the learning rate and one of my model parameters, learned_temperature. A minimal skeleton is sketched below.
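
For orientation, here is a minimal sketch of such a callback. It assumes only the hooks and attributes already used in the gist above (the Callback base class, on_update_end, trainer.num_updates, and trainer.logistics_callback.tb_writer / .log_interval); the MyLoggingCallback name and the "train/my_metric" scalar are hypothetical placeholders:

```python
# minimal user-callback sketch, assuming the MMF Callback API used in user.py above
from mmf.trainers.callbacks.base import Callback


class MyLoggingCallback(Callback):  # hypothetical name
    def __init__(self, config, trainer):
        # the base class stores config and trainer on self
        super().__init__(config, trainer)

    def on_update_end(self, **kwargs):
        trainer = self.trainer
        num_updates = trainer.num_updates
        if num_updates % trainer.logistics_callback.log_interval == 0:
            tb_writer = trainer.logistics_callback.tb_writer
            # log any scalar you like; "train/my_metric" is a placeholder
            tb_writer.add_scalar("train/my_metric", 0.0, num_updates)
```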

Step 2: Add the callback to the all_user_callbacks list
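
This is the last line of user.py above:

```python
all_user_callbacks = [VseInfCallback, TensorBoardLoggingCallback]
```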

Step 3: Import all_user_callbacks in mmf/trainers/mmf_trainer.py file
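
Concretely, the import (copied from the mmf_trainer.py snippet above):

```python
from user import all_user_callbacks
```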

Step 4: Add two lines in the configure_callbacks method, as in https://gist.github.com/jayleicn/c19123454f9d310ed6029115b3dc1ac5#file-mmf_trainer-py-L20-L22. This registers the callbacks with the trainer.
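
The two lines, as they appear at the end of configure_callbacks in the snippet above:

```python
for callback_cls in all_user_callbacks:
    self.callbacks.append(callback_cls(self.config, self))
```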

Done!

BTW, with the same mechanism you can also do other customizations, such as controlling which parts of the model are frozen at different steps, as in this callback: https://gist.github.com/jayleicn/c19123454f9d310ed6029115b3dc1ac5#file-user-py-L21
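
For reference, both callbacks above read config.training.vse_inf_training. Here is a sketch of that config section built with omegaconf; the key names come from the callback code above, but the values are hypothetical, and where this section lives in your YAML depends on your MMF setup:

```python
from omegaconf import OmegaConf

# hypothetical values; key names match what VseInfCallback.__init__ reads
config = OmegaConf.create(
    {
        "training": {
            "vse_inf_training": {
                "enable": True,
                "warmup_steps": 1000,             # hypothetical
                "freeze_img_backbone": True,      # default True
                "freeze_blocks_after_warmup": 0,  # 0 == unfreeze all
                "max_violation": False,           # default False
            }
        }
    }
)
```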
