@ortegatron
Last active March 14, 2024 19:10
Trainer with Loss on Validation for Detectron2
from detectron2.engine.hooks import HookBase
from detectron2.evaluation import inference_context
from detectron2.utils.logger import log_every_n_seconds
from detectron2.data import DatasetMapper, build_detection_test_loader
import detectron2.utils.comm as comm
import torch
import time
import datetime
import logging
import numpy as np


class LossEvalHook(HookBase):
    def __init__(self, eval_period, model, data_loader):
        self._model = model
        self._period = eval_period
        self._data_loader = data_loader

    def _do_loss_eval(self):
        # Copying inference_on_dataset from evaluator.py
        total = len(self._data_loader)
        num_warmup = min(5, total - 1)

        start_time = time.perf_counter()
        total_compute_time = 0
        losses = []
        for idx, inputs in enumerate(self._data_loader):
            if idx == num_warmup:
                # Reset the timers once the warmup iterations are done.
                start_time = time.perf_counter()
                total_compute_time = 0
            start_compute_time = time.perf_counter()
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - start_compute_time
            iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
            seconds_per_img = total_compute_time / iters_after_start
            if idx >= num_warmup * 2 or seconds_per_img > 5:
                # Periodically log progress and an ETA, like inference_on_dataset does.
                total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
                eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
                log_every_n_seconds(
                    logging.INFO,
                    "Loss on Validation done {}/{}. {:.4f} s / img. ETA={}".format(
                        idx + 1, total, seconds_per_img, str(eta)
                    ),
                    n=5,
                )
            loss_batch = self._get_loss(inputs)
            losses.append(loss_batch)
        mean_loss = np.mean(losses)
        # Record the mean validation loss so it gets written to metrics.json.
        self.trainer.storage.put_scalar('validation_loss', mean_loss)
        comm.synchronize()

        return losses

    def _get_loss(self, data):
        # How loss is calculated on train_loop: the model is in training mode,
        # so calling it returns a dict of loss terms rather than predictions.
        metrics_dict = self._model(data)
        metrics_dict = {
            k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
            for k, v in metrics_dict.items()
        }
        total_losses_reduced = sum(loss for loss in metrics_dict.values())
        return total_losses_reduced

    def after_step(self):
        next_iter = self.trainer.iter + 1
        is_final = next_iter == self.trainer.max_iter
        if is_final or (self._period > 0 and next_iter % self._period == 0):
            self._do_loss_eval()
        # Leftover debug scalar; this is what shows up as "timetest" in metrics.json.
        self.trainer.storage.put_scalars(timetest=12)
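For reference, because the model stays in training mode, each call inside _get_loss returns a dict of loss terms; with a Mask R-CNN config these are the same keys that appear in the metrics.json excerpt quoted in the comments below (loss_cls, loss_box_reg, loss_mask, loss_rpn_cls, loss_rpn_loc), and _get_loss simply sums them. An illustrative sketch with rounded, made-up values:

# Illustrative only; the keys depend on the model architecture and the values on the batch.
metrics_dict = {
    "loss_cls": 0.106, "loss_box_reg": 0.042, "loss_mask": 0.646,
    "loss_rpn_cls": 0.131, "loss_rpn_loc": 0.055,
}
total_loss = sum(metrics_dict.values())  # one scalar per validation batch, here ~0.98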
import os
from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator


class MyTrainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        return COCOEvaluator(dataset_name, cfg, True, output_folder)

    def build_hooks(self):
        hooks = super().build_hooks()
        # Insert before the last hook (PeriodicWriter) so the new scalar is flushed.
        hooks.insert(-1, LossEvalHook(
            self.cfg.TEST.EVAL_PERIOD,
            self.model,
            build_detection_test_loader(
                self.cfg,
                self.cfg.DATASETS.TEST[0],
                # is_train=True so the mapper keeps ground-truth annotations,
                # which the model needs in order to compute losses.
                DatasetMapper(self.cfg, True)
            )
        ))
        return hooks
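To actually train with this, register your datasets and set cfg.TEST.EVAL_PERIOD so the hook fires periodically. A minimal usage sketch; the dataset names, annotation paths, class count, and the Mask R-CNN config chosen here are placeholders, adjust them to your setup:

import os
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data.datasets import register_coco_instances

# Hypothetical COCO-format datasets; replace names and paths with your own.
register_coco_instances("my_dataset_train", {}, "train/annotations.json", "train/images")
register_coco_instances("my_dataset_val", {}, "val/annotations.json", "val/images")

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1        # set to your number of classes
cfg.DATASETS.TRAIN = ("my_dataset_train",)
cfg.DATASETS.TEST = ("my_dataset_val",)    # used by the hook's validation loader
cfg.TEST.EVAL_PERIOD = 100                 # run the validation-loss pass every 100 iterations
cfg.SOLVER.MAX_ITER = 1000
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

trainer = MyTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()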
import json
import matplotlib.pyplot as plt

experiment_folder = './output/model_iter4000_lr0005_wf1_date2020_03_20__05_16_45'

def load_json_arr(json_path):
    # metrics.json is written as one JSON object per line.
    lines = []
    with open(json_path, 'r') as f:
        for line in f:
            lines.append(json.loads(line))
    return lines

experiment_metrics = load_json_arr(experiment_folder + '/metrics.json')

# Filter each series, since not every row contains every key
# (evaluation rows have no total_loss; validation_loss only appears
# on the iterations where the hook ran).
plt.plot(
    [x['iteration'] for x in experiment_metrics if 'total_loss' in x],
    [x['total_loss'] for x in experiment_metrics if 'total_loss' in x])
plt.plot(
    [x['iteration'] for x in experiment_metrics if 'validation_loss' in x],
    [x['validation_loss'] for x in experiment_metrics if 'validation_loss' in x])
plt.legend(['total_loss', 'validation_loss'], loc='upper left')
plt.show()
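As a small optional extension (not part of the original gist), the same metrics can be used to pick the iteration where validation loss was lowest, e.g. to decide which checkpoint to keep:

# Report the iteration with the lowest recorded validation loss.
val_entries = [x for x in experiment_metrics if 'validation_loss' in x]
if val_entries:
    best = min(val_entries, key=lambda x: x['validation_loss'])
    print('Lowest validation loss {:.4f} at iteration {}'.format(
        best['validation_loss'], best['iteration']))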
@daeyeoplee

I put all of this into my .py file.

AssertionError:

I keep getting this error with no other message.

How can I handle this problem? Thank you.

@xxxming730

In my metrics.json, I get this:
{"bbox/AP": 0.0031121327534856355, "bbox/AP-background": NaN, "bbox/AP-ck": 0.0275027502750275, "bbox/AP-fml": 0.0, "bbox/AP-fmm": 0.0, "bbox/AP-gh": 0.0, "bbox/AP-gpe": 0.0, "bbox/AP-gpr": 0.0005064445063432175, "bbox/AP-s": 0.0, "bbox/AP-sc": 0.0, "bbox/AP-ss": 0.0, "bbox/AP50": 0.00639308034241901, "bbox/AP75": 0.0, "bbox/APl": 0.00016924769400016924, "bbox/APm": 0.0, "bbox/APs": 0.03000300030003, "iteration": 150, "segm/AP": 0.001222344456667889, "segm/AP-background": NaN, "segm/AP-ck": 0.011001100110011, "segm/AP-fml": 0.0, "segm/AP-fmm": 0.0, "segm/AP-gh": 0.0, "segm/AP-gpe": 0.0, "segm/AP-gpr": 0.0, "segm/AP-s": 0.0, "segm/AP-sc": 0.0, "segm/AP-ss": 0.0, "segm/AP50": 0.006111722283339444, "segm/AP75": 0.0, "segm/APl": 0.0, "segm/APm": 0.0, "segm/APs": 0.0088008800880088}
{"data_time": 1.5641630000000077, "eta_seconds": 19170.847725999975, "fast_rcnn/cls_accuracy": 0.98468017578125, "fast_rcnn/false_negative": 1.0, "fast_rcnn/fg_cls_accuracy": 0.0, "iteration": 179, "loss_box_reg": 0.041649249847978354, "loss_cls": 0.10573109425604343, "loss_mask": 0.6462322175502777, "loss_rpn_cls": 0.13051373744383454, "loss_rpn_loc": 0.05502640060149133, "lr": 4.4955249999999996e-05, "mask_rcnn/accuracy": 0.7032937057778648, "mask_rcnn/false_negative": 0.050562366590145194, "mask_rcnn/false_positive": 0.6311390203732659, "roi_head/num_bg_samples": 504.15625, "roi_head/num_fg_samples": 7.84375, "rpn/num_neg_anchors": 236.09375, "rpn/num_pos_anchors": 19.90625, "time": 3.7675802499999236, "total_loss": 1.0010524874087423}
{"iteration": 178, "timetest": 12.0}
...

There is no 'validation_loss' entry.

However, 'validation_loss' is printed in the console.

I also have this question: what is this line doing? self.trainer.storage.put_scalars(timetest=12)

If you have this problem, we can discuss it together. If anyone knows why, please let me know. Thank you very much. Hope everyone is OK.

@xxxming730

A single GPU is fine: we can output validation_loss and record it. But when training on multiple GPUs, the metrics.json output is wrong, as shown above, and we have no idea what is being written. Has anyone solved this problem? Thanks a lot.

@Scarlet3101

In my metrics.json, i get this: {"bbox/AP": 0.0031121327534856355, "bbox/AP-background": NaN, "bbox/AP-ck": 0.0275027502750275, "bbox/AP-fml": 0.0, "bbox/AP-fmm": 0.0, "bbox/AP-gh": 0.0, "bbox/AP-gpe": 0.0, "bbox/AP-gpr": 0.0005064445063432175, "bbox/AP-s": 0.0, "bbox/AP-sc": 0.0, "bbox/AP-ss": 0.0, "bbox/AP50": 0.00639308034241901, "bbox/AP75": 0.0, "bbox/APl": 0.00016924769400016924, "bbox/APm": 0.0, "bbox/APs": 0.03000300030003, "iteration": 150, "segm/AP": 0.001222344456667889, "segm/AP-background": NaN, "segm/AP-ck": 0.011001100110011, "segm/AP-fml": 0.0, "segm/AP-fmm": 0.0, "segm/AP-gh": 0.0, "segm/AP-gpe": 0.0, "segm/AP-gpr": 0.0, "segm/AP-s": 0.0, "segm/AP-sc": 0.0, "segm/AP-ss": 0.0, "segm/AP50": 0.006111722283339444, "segm/AP75": 0.0, "segm/APl": 0.0, "segm/APm": 0.0, "segm/APs": 0.0088008800880088} {"data_time": 1.5641630000000077, "eta_seconds": 19170.847725999975, "fast_rcnn/cls_accuracy": 0.98468017578125, "fast_rcnn/false_negative": 1.0, "fast_rcnn/fg_cls_accuracy": 0.0, "iteration": 179, "loss_box_reg": 0.041649249847978354, "loss_cls": 0.10573109425604343, "loss_mask": 0.6462322175502777, "loss_rpn_cls": 0.13051373744383454, "loss_rpn_loc": 0.05502640060149133, "lr": 4.4955249999999996e-05, "mask_rcnn/accuracy": 0.7032937057778648, "mask_rcnn/false_negative": 0.050562366590145194, "mask_rcnn/false_positive": 0.6311390203732659, "roi_head/num_bg_samples": 504.15625, "roi_head/num_fg_samples": 7.84375, "rpn/num_neg_anchors": 236.09375, "rpn/num_pos_anchors": 19.90625, "time": 3.7675802499999236, "total_loss": 1.0010524874087423} {"iteration": 178, "timetest": 12.0} ...

there is no 'validation_loss'

However, the 'validation_loss' is output in the console

and I also have this question:

Hi What is this line doing? self.trainer.storage.put_scalars(timetest=12)

If you have this problem, we can discuss it together. If anyone knows why, please let me know. Thank you very much. Hope everyone is OK

Hi! I have the same problem when I set eval_period to 100, but if I set it to 50 then everything works correctly and the validation_loss result is written to metrics.json.
