
@ortegatron
Last active March 14, 2024 19:10
Trainer with Loss on Validation for Detectron2
from detectron2.engine.hooks import HookBase
from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator, inference_context
from detectron2.utils.logger import log_every_n_seconds
from detectron2.data import DatasetMapper, build_detection_test_loader
import detectron2.utils.comm as comm
import torch
import numpy as np
import time
import datetime
import logging
import os


class LossEvalHook(HookBase):
    def __init__(self, eval_period, model, data_loader):
        self._model = model
        self._period = eval_period
        self._data_loader = data_loader

    def _do_loss_eval(self):
        # Copying inference_on_dataset from evaluator.py
        total = len(self._data_loader)
        num_warmup = min(5, total - 1)

        start_time = time.perf_counter()
        total_compute_time = 0
        losses = []
        for idx, inputs in enumerate(self._data_loader):
            if idx == num_warmup:
                start_time = time.perf_counter()
                total_compute_time = 0
            start_compute_time = time.perf_counter()
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - start_compute_time
            iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
            seconds_per_img = total_compute_time / iters_after_start
            if idx >= num_warmup * 2 or seconds_per_img > 5:
                total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
                eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
                log_every_n_seconds(
                    logging.INFO,
                    "Loss on Validation done {}/{}. {:.4f} s / img. ETA={}".format(
                        idx + 1, total, seconds_per_img, str(eta)
                    ),
                    n=5,
                )
            loss_batch = self._get_loss(inputs)
            losses.append(loss_batch)
        mean_loss = np.mean(losses)
        self.trainer.storage.put_scalar('validation_loss', mean_loss)
        comm.synchronize()

        return losses

    def _get_loss(self, data):
        # How loss is calculated on train_loop
        metrics_dict = self._model(data)
        metrics_dict = {
            k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
            for k, v in metrics_dict.items()
        }
        total_losses_reduced = sum(loss for loss in metrics_dict.values())
        return total_losses_reduced

    def after_step(self):
        next_iter = self.trainer.iter + 1
        is_final = next_iter == self.trainer.max_iter
        if is_final or (self._period > 0 and next_iter % self._period == 0):
            self._do_loss_eval()
        self.trainer.storage.put_scalars(timetest=12)


class MyTrainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        return COCOEvaluator(dataset_name, cfg, True, output_folder)

    def build_hooks(self):
        hooks = super().build_hooks()
        hooks.insert(-1, LossEvalHook(
            self.cfg.TEST.EVAL_PERIOD,  # was `cfg.TEST.EVAL_PERIOD`; only self.cfg is in scope here
            self.model,
            build_detection_test_loader(
                self.cfg,
                self.cfg.DATASETS.TEST[0],
                DatasetMapper(self.cfg, True)
            )
        ))
        return hooks
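
A minimal usage sketch, not part of the original gist: the hook fires every cfg.TEST.EVAL_PERIOD iterations, so that value has to be set on the config before MyTrainer is built. The dataset names and the model-zoo config below are placeholders.

from detectron2.config import get_cfg
from detectron2 import model_zoo

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("my_train",)   # placeholder: a registered training dataset
cfg.DATASETS.TEST = ("my_val",)      # used by LossEvalHook via cfg.DATASETS.TEST[0]
cfg.TEST.EVAL_PERIOD = 100           # compute the validation loss every 100 iterations
cfg.OUTPUT_DIR = "./output"

trainer = MyTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()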
import json
import matplotlib.pyplot as plt

experiment_folder = './output/model_iter4000_lr0005_wf1_date2020_03_20__05_16_45'

def load_json_arr(json_path):
    # metrics.json is one JSON object per line
    lines = []
    with open(json_path, 'r') as f:
        for line in f:
            lines.append(json.loads(line))
    return lines

experiment_metrics = load_json_arr(experiment_folder + '/metrics.json')

# Filter on the key in each case: evaluation entries have no 'total_loss',
# and only some entries carry 'validation_loss'.
plt.plot(
    [x['iteration'] for x in experiment_metrics if 'total_loss' in x],
    [x['total_loss'] for x in experiment_metrics if 'total_loss' in x])
plt.plot(
    [x['iteration'] for x in experiment_metrics if 'validation_loss' in x],
    [x['validation_loss'] for x in experiment_metrics if 'validation_loss' in x])
plt.legend(['total_loss', 'validation_loss'], loc='upper left')
plt.show()
@balajihosur

@ortegatron Can you please explain the cfg file format? Where should cfg.TEST.EVAL_PERIOD be added in the cfg file? Can you provide the format of the config file?

@konrad98ft

konrad98ft commented Jan 13, 2022

I used the following MyTrainer and, after validation started, got an error:
RuntimeError: DataLoader worker (pid(s) 13656, 10880, 3784, 12464) exited unexpectedly
How can I solve it?

@MLDeep414

@alexriedel1

Hi,
I have calculated validation_loss; based on this, I am trying to implement early stopping using the code below.

class EarlyStopping():
    """
    Early stopping to stop the training when the loss does not improve after
    certain epochs.
    """
    def __init__(self, patience=5, min_delta=0):
        """
        :param patience: how many epochs to wait before stopping when loss is
            not improving
        :param min_delta: minimum difference between new loss and old loss for
            new loss to be considered as an improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif self.best_loss - val_loss > self.min_delta:
            self.best_loss = val_loss
            # reset counter if validation loss improves
            self.counter = 0
        elif self.best_loss - val_loss < self.min_delta:
            self.counter += 1
            print(f"INFO: Early stopping counter {self.counter} of {self.patience}")
            if self.counter >= self.patience:
                print('INFO: Early stopping')
                self.early_stop = True
but I could not get it to work properly. Could you please help me?

@alexriedel1

@MLDeep414 please format your code properly and show what actual problem you're running into in order to get help

@alexriedel1

alexriedel1 commented Feb 15, 2022

@MLDeep414 sorry, but the formatting is even worse this time, as some of the code is formatted and some is not :(
Do you get any error? Is the hook not executed? What is the problem here?

@alexriedel1

https://detectron2.readthedocs.io/en/latest/modules/engine.html#detectron2.engine.HookBase

The hook as a child of HookBase will be called as stated in the docs above.

You have to implement the after_step method if you want to check for early stopping after each step (which may be too often, so only check every reasonable number of steps inside your method!). Inside your hook you can access the trainer object to get the necessary information about your training state.
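
For illustration only, not code from this thread: a minimal hook sketch along those lines, reusing the EarlyStopping class posted above (with __init__/__call__ restored) and assuming the LossEvalHook has already put 'validation_loss' into the trainer's EventStorage. The name EarlyStoppingHook and the period argument are invented for the example.

from detectron2.engine.hooks import HookBase

class EarlyStoppingHook(HookBase):
    # Check 'validation_loss' every `period` iterations and abort training
    # once it has stopped improving for `patience` checks.
    def __init__(self, period, patience=5, min_delta=0):
        self._period = period
        self._stopper = EarlyStopping(patience=patience, min_delta=min_delta)

    def after_step(self):
        next_iter = self.trainer.iter + 1
        if self._period > 0 and next_iter % self._period == 0:
            # EventStorage.latest() maps scalar name -> (value, iteration)
            latest = self.trainer.storage.latest()
            if "validation_loss" in latest:
                self._stopper(latest["validation_loss"][0])
                if self._stopper.early_stop:
                    # The default training loop has no stop flag, so the simplest
                    # way out in this sketch is to raise; the exception propagates
                    # out of trainer.train() and training ends.
                    raise RuntimeError("Early stopping: validation loss stopped improving")

It could be registered like the LossEvalHook, e.g. inserted in build_hooks or via trainer.register_hooks([...]), with period typically equal to cfg.TEST.EVAL_PERIOD so the check runs right after each validation pass.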

@pieterblok

A while ago I created a similar hook for validation: not based on the loss but on the validation performance (which makes sense in some contexts, as we also evaluate the test set on mAP). Early stopping would have similar functionality, although it really stops the training, while my method continues training but saves the best model automatically.

You can find the procedure here:
https://github.com/pieterblok/maskal/blob/5e1b1e9b6c14a423b22d3218da66120cbb0b7f7c/maskAL.py#L319

It might be of interest…

@veer5551

Hi,
Thanks for the amazing work!
I included the above LossEvalHook in the training and found that the extra workers that are created (for the data loaders?) are not exiting and remain locked. This makes the code run out of resources and the machine stops after a while!

Using Multi-GPU training.
Here is the issue raised on the detectron2 repo: facebookresearch/detectron2#3953

Any thoughts on this, guys? I am blocked from doing any training :(

Thanks!

@EmmaVanPuyenbr

EmmaVanPuyenbr commented Apr 6, 2022

hi,
Thanks for the nice code.
I was wondering how you were able to store the logs under a customized name. Where did you first define the model name in

experiment_folder = './output/model_iter4000_lr0005_wf1_date2020_03_20__05_16_45' 

and also the /metrics.json name? Can you change the filename?

thanks!

@hikmatkhan

Hi
What is this line doing?
self.trainer.storage.put_scalars(timetest=12)

@daeyeoplee

I put all of this in my .py file.

AssertionError:

This error keeps occurring with no other message.

How can I handle this problem? Thank you.

@xxxming730

In my metrics.json, I get this:
{"bbox/AP": 0.0031121327534856355, "bbox/AP-background": NaN, "bbox/AP-ck": 0.0275027502750275, "bbox/AP-fml": 0.0, "bbox/AP-fmm": 0.0, "bbox/AP-gh": 0.0, "bbox/AP-gpe": 0.0, "bbox/AP-gpr": 0.0005064445063432175, "bbox/AP-s": 0.0, "bbox/AP-sc": 0.0, "bbox/AP-ss": 0.0, "bbox/AP50": 0.00639308034241901, "bbox/AP75": 0.0, "bbox/APl": 0.00016924769400016924, "bbox/APm": 0.0, "bbox/APs": 0.03000300030003, "iteration": 150, "segm/AP": 0.001222344456667889, "segm/AP-background": NaN, "segm/AP-ck": 0.011001100110011, "segm/AP-fml": 0.0, "segm/AP-fmm": 0.0, "segm/AP-gh": 0.0, "segm/AP-gpe": 0.0, "segm/AP-gpr": 0.0, "segm/AP-s": 0.0, "segm/AP-sc": 0.0, "segm/AP-ss": 0.0, "segm/AP50": 0.006111722283339444, "segm/AP75": 0.0, "segm/APl": 0.0, "segm/APm": 0.0, "segm/APs": 0.0088008800880088}
{"data_time": 1.5641630000000077, "eta_seconds": 19170.847725999975, "fast_rcnn/cls_accuracy": 0.98468017578125, "fast_rcnn/false_negative": 1.0, "fast_rcnn/fg_cls_accuracy": 0.0, "iteration": 179, "loss_box_reg": 0.041649249847978354, "loss_cls": 0.10573109425604343, "loss_mask": 0.6462322175502777, "loss_rpn_cls": 0.13051373744383454, "loss_rpn_loc": 0.05502640060149133, "lr": 4.4955249999999996e-05, "mask_rcnn/accuracy": 0.7032937057778648, "mask_rcnn/false_negative": 0.050562366590145194, "mask_rcnn/false_positive": 0.6311390203732659, "roi_head/num_bg_samples": 504.15625, "roi_head/num_fg_samples": 7.84375, "rpn/num_neg_anchors": 236.09375, "rpn/num_pos_anchors": 19.90625, "time": 3.7675802499999236, "total_loss": 1.0010524874087423}
{"iteration": 178, "timetest": 12.0}
...

there is no 'validation_loss'.

However, 'validation_loss' is printed in the console.

I also have the same question as above:

What is this line doing? self.trainer.storage.put_scalars(timetest=12)

If you have this problem, we can discuss it together. If anyone knows why, please let me know. Thank you very much. Hope everyone is OK

@xxxming730

A single GPU is fine: we can output validation_loss and record it. But when training on multiple GPUs, the metrics.json output is wrong like above, and we have no idea what is being written. Has anyone solved this problem? Thanks a lot.

@Scarlet3101

In my metrics.json, I get this: […] there is no 'validation_loss'. However, 'validation_loss' is printed in the console. […]

Hi! I have the same problem when I set eval_period to 100, but if I set it to 50 then everything works correctly and the validation_loss result is written to metrics.json.
