Skip to content

Instantly share code, notes, and snippets.

@Aunsiels
Created June 12, 2023 09:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Aunsiels/62bc76e64d1914d0b5fa9f6042a515ee to your computer and use it in GitHub Desktop.
Save Aunsiels/62bc76e64d1914d0b5fa9f6042a515ee to your computer and use it in GitHub Desktop.
# Script to merge several scalars from different experiments.
# Experiments that follow the same pattern <name>_epoch<X>_part<Y> will be merged into a single directory <name>.
# The _part<Y> is optional
# If the experiment name contains neither an epoch nor a part, its scalars are replicated once per epoch,
# i.e. as many times as the highest epoch number found among the experiments.
# The program takes as input the original runs directory and the wanted output merged runs directory.
# How to use:
# python tensorboard_merger.py source target
# Example:
# python tensorboard_merger.py /tmp/runs /tmp/runs_merged
from pprint import pprint
from tensorboard.backend.event_processing import event_accumulator
from torch.utils.tensorboard import SummaryWriter
from sys import argv
import os
import re
# Directory holding the original runs, taken from the first command-line argument.
BASE_DIR = argv[1]
# Directory that receives the merged runs, taken from the second command-line argument.
TARGET_DIR = argv[2]
# Matches run names of the form "<name>_epoch<X>" with an optional "_part<Y>"
# suffix (and an optional trailing slash).
# - The epoch and part numbers must contain at least one digit, so a name such
#   as "foo_epoch" no longer matches with an empty epoch that int() cannot parse.
# - The pattern is anchored at the end, so a name with an unrelated suffix
#   ("run_epoch3_extra") is not mistaken for "run_epoch3".
RE_GROUP = re.compile(r"(?P<name>.*)_epoch(?P<epoch>\d+)(_part(?P<part>\d+))?/?$")
def get_groups():
    """Group the entries of BASE_DIR by experiment name.

    Every entry returned by ``os.listdir(BASE_DIR)`` is matched against
    RE_GROUP. Entries that match are collected under the captured ``name`` as
    ``(epoch, part)`` tuples, where ``part`` is -1 when the entry has no part
    number. Entries that do not match (or whose epoch number is empty) are
    baselines: they are stored under their own name as ``(-1, -1)``.

    Returns:
        dict mapping each experiment name to its sorted list of
        ``(epoch, part)`` tuples.
    """
    groups = {}
    for entry in os.listdir(BASE_DIR):
        match = RE_GROUP.search(entry)
        # An "_epoch" marker without digits cannot be converted to an int, so
        # such an entry is handled like any other unmatched name.
        if match is None or not match.group("epoch"):
            # Append instead of overwriting: if "<entry>_epoch<X>" runs were
            # already collected under this name they must not be discarded,
            # whatever order os.listdir returns the entries in.
            groups.setdefault(entry, []).append((-1, -1))
            continue
        epoch = int(match.group("epoch"))
        # A missing (or empty) part number is encoded as -1.
        part = int(match.group("part") or -1)
        groups.setdefault(match.group("name"), []).append((epoch, part))
    return {name: sorted(runs) for name, runs in groups.items()}
def get_fullname(base, epoch, part):
    """Rebuild a run name from its experiment name, epoch and part.

    A value of -1 for ``epoch`` or ``part`` means "absent": the matching
    "_epoch<X>" / "_part<Y>" suffix is then left out.
    """
    suffixes = []
    if epoch != -1:
        suffixes.append("_epoch" + str(epoch))
    if part != -1:
        suffixes.append("_part" + str(part))
    return base + "".join(suffixes)
def merge_group(name, epoch_part, max_epoch):
    """Merge the scalars of every run of one experiment into TARGET_DIR/<name>.

    The runs are read from ``BASE_DIR/<name>[_epoch<X>][_part<Y>]`` in the
    order given by ``epoch_part`` and written to a single SummaryWriter.
    For each scalar tag:

    * steps are shifted by the highest step written for the previous epoch,
      so that consecutive epochs follow each other; all parts of one epoch
      share the same shift;
    * wall times of the first run are kept as they are, and every later run is
      shifted so that it starts at the latest wall time already written.

    Args:
        name: experiment name (the merged run is written under this name).
        epoch_part: list of ``(epoch, part)`` tuples; -1 means "absent".
        max_epoch: number of times a baseline (a single ``(-1, -1)`` entry)
            is replayed.
    """
    print("Merging", name)
    if len(epoch_part) == 1 and epoch_part[0] == (-1, -1):
        # We have a baseline on a single epoch: replay it once per epoch, and
        # at least once so it is still written when no run has an epoch
        # (max_epoch == 0 used to produce an empty list and crash in max()).
        epoch_part = [(-1, -1)] * max(max_epoch, 1)

    # Highest part number of each epoch. The step offset must advance after
    # the last part of *that* epoch; comparing against one global maximum
    # would never advance it for epochs that have fewer parts.
    last_part = {}
    for epoch, part in epoch_part:
        last_part[epoch] = max(part, last_part.get(epoch, part))

    step_offset = {}  # tag -> value added to the steps of the current epoch
    last_wall_time = {}  # tag -> latest wall time written so far
    writer = SummaryWriter(log_dir=os.path.join(TARGET_DIR, name))
    try:
        for epoch, part in epoch_part:
            run_dir = os.path.join(BASE_DIR, get_fullname(name, epoch, part))
            # A size guidance of 0 keeps every scalar event; the default
            # guidance keeps only a sample of long runs.
            ea = event_accumulator.EventAccumulator(
                run_dir, size_guidance={event_accumulator.SCALARS: 0}
            )
            ea.Reload()
            for key in ea.scalars.Keys():
                scalars = ea.Scalars(key)
                if not scalars:
                    continue
                base_step = step_offset.setdefault(key, 0)
                if key in last_wall_time:
                    # Rebase this run so that its first event lands on the
                    # latest wall time already written for this tag.
                    wall_shift = last_wall_time[key] - scalars[0].wall_time
                else:
                    # First run for this tag: keep the recorded wall times.
                    wall_shift = 0
                highest_step = 0
                highest_wall = last_wall_time.get(key, 0)
                for scalar_event in scalars:
                    step = scalar_event.step + base_step
                    wall_time = scalar_event.wall_time + wall_shift
                    writer.add_scalar(key, scalar_event.value, step, wall_time)
                    highest_step = max(highest_step, step)
                    highest_wall = max(highest_wall, wall_time)
                if part == last_part[epoch]:
                    step_offset[key] = highest_step
                last_wall_time[key] = highest_wall
    finally:
        # Flush and close the writer even if reading a run fails.
        writer.close()
def get_max_epoch(groups):
    """Return the highest epoch number found in ``groups``, never below 0.

    ``groups`` maps experiment names to lists of ``(epoch, part)`` tuples.
    """
    # Seeding with 0 keeps the result at 0 for an empty mapping and when
    # every entry is a baseline (epoch -1).
    epochs = [0]
    for runs in groups.values():
        epochs.extend(epoch for epoch, _part in runs)
    return max(epochs)
def merge():
    """Discover the experiment groups, print them, and merge each one."""
    experiment_groups = get_groups()
    pprint(experiment_groups)
    # Baselines are replayed this many times by merge_group.
    replay_count = get_max_epoch(experiment_groups)
    for group_name in experiment_groups:
        merge_group(group_name, experiment_groups[group_name], replay_count)
# Run the merge only when this file is executed as a script.
if __name__ == "__main__":
    merge()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment