Created
June 12, 2023 09:09
-
-
Save Aunsiels/62bc76e64d1914d0b5fa9f6042a515ee to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Script to merge several scalars from different experiments. | |
# Experiments that follow the same pattern <name>_epoch<X>_part<Y> will be merged into a single directory <name>. | |
# The _part<Y> is optional | |
# If the experiment name contains no epoch and no part, it will be replicated several times | |
# (= max number of epochs in the experiments) | |
# The program takes as input the original runs directory and the wanted output merged runs directory. | |
# How to use: | |
# python tensorboard_merger.py source target | |
# Example: | |
# python tensorboard_merger.py /tmp/runs /tmp/runs_merged
from pprint import pprint | |
from tensorboard.backend.event_processing import event_accumulator | |
from torch.utils.tensorboard import SummaryWriter | |
from sys import argv | |
import os | |
import re | |
# Command-line arguments: source runs directory and merged output directory.
BASE_DIR = argv[1]
TARGET_DIR = argv[2]
# Matches <name>_epoch<X>[_part<Y>] run directory names (optional trailing "/").
# BUG FIX: \d+ instead of \d* — the empty digit group previously matched names
# like "exp_epoch_part1" and made int(match.group("epoch")) raise ValueError
# in get_groups(); such names now simply fail to match and are treated as
# baseline runs instead of crashing.
RE_GROUP = re.compile(r"(?P<name>.*)_epoch(?P<epoch>\d+)(_part(?P<part>\d+))?/?")
def get_groups():
    """Scan BASE_DIR and group run directories by experiment name.

    Returns a dict mapping experiment name -> sorted list of (epoch, part)
    tuples. A directory that does not match the <name>_epoch<X>[_part<Y>]
    pattern is kept under its own name with the sentinel entry [(-1, -1)].
    A missing _part<Y> suffix is recorded as part -1.
    """
    groups = {}
    for entry in os.listdir(BASE_DIR):
        match = RE_GROUP.search(entry)
        if match is None:
            # Baseline run without epoch/part markers.
            groups[entry] = [(-1, -1)]
            continue
        name = match.group("name")
        epoch = int(match.group("epoch"))
        part = int(match.group("part") or -1)
        groups.setdefault(name, []).append((epoch, part))
    # Process runs in (epoch, part) order so steps accumulate correctly.
    for runs in groups.values():
        runs.sort()
    return groups
def get_fullname(base, epoch, part):
    """Rebuild the on-disk run directory name from its components.

    A value of -1 for epoch or part means "absent" and contributes no suffix.
    """
    suffixes = []
    if epoch != -1:
        suffixes.append(f"_epoch{epoch}")
    if part != -1:
        suffixes.append(f"_part{part}")
    return base + "".join(suffixes)
def merge_group(name, epoch_part, max_epoch):
    """Merge the scalar events of one experiment group into TARGET_DIR/<name>.

    Parameters:
        name: base experiment name; becomes the target subdirectory.
        epoch_part: sorted list of (epoch, part) tuples identifying the source
            runs; a single (-1, -1) entry denotes a baseline run recorded
            without epoch/part suffixes.
        max_epoch: highest epoch seen across all groups; a baseline run is
            replayed that many times so its curves span the same step range
            as the epoch-split experiments.
    """
    print("Merging", name)
    if len(epoch_part) == 1 and epoch_part[0] == (-1, -1):
        # Baseline recorded in a single run: replay it once per epoch.
        epoch_part = [(-1, -1)] * max_epoch
    writer = SummaryWriter(log_dir=os.path.join(TARGET_DIR, name))
    delta_step = {}     # per-tag step offset carried over between runs
    all_wall_time = {}  # per-tag latest wall time, keeps timestamps monotonic
    max_part = max(part for _, part in epoch_part)
    for epoch, part in epoch_part:
        ea = event_accumulator.EventAccumulator(
            os.path.join(BASE_DIR, get_fullname(name, epoch, part)))
        ea.Reload()
        for key in ea.scalars.Keys():
            if key not in delta_step:
                delta_step[key] = 0
                all_wall_time[key] = 0
            max_step = delta_step[key]
            local_step = 0
            # BUG FIX: resume from the last wall time of the previous run.
            # The original set this to 0, so all_wall_time was written but
            # never read and the continuation branch below was dead code —
            # merged runs kept their original (overlapping) timestamps.
            max_wall_time = all_wall_time[key]
            scalars = ea.Scalars(key)
            if not scalars:
                # Robustness: nothing to merge for this tag in this run
                # (the original would raise IndexError on scalars[0]).
                continue
            first_se = scalars[0]
            if max_wall_time != 0:
                # Shift this run's wall times so they continue seamlessly
                # after the previous run instead of restarting.
                delta_wall = first_se.wall_time
            else:
                delta_wall = -1  # first run for this tag: keep raw wall times
            for scalar_event in scalars:
                if delta_wall == -1:
                    wall_time = scalar_event.wall_time
                else:
                    wall_time = max_wall_time + scalar_event.wall_time - delta_wall
                writer.add_scalar(key, scalar_event.value,
                                  scalar_event.step + max_step, wall_time)
                local_step = max(local_step, scalar_event.step + max_step)
                max_wall_time = max(max_wall_time, wall_time)
            if part == max_part:
                # Advance the offsets only after the last part of an epoch,
                # so all parts of the same epoch share one step base.
                delta_step[key] = local_step
                all_wall_time[key] = max_wall_time
    writer.close()
def get_max_epoch(groups):
    """Return the highest epoch number found in any group.

    Baseline groups contribute -1 entries, so the result is floored at 0,
    matching the original accumulator that started from 0.
    """
    all_epochs = [epoch
                  for epoch_parts in groups.values()
                  for epoch, _ in epoch_parts]
    return max([0] + all_epochs)
def merge():
    """Entry point: group the source runs, then merge each group into TARGET_DIR."""
    groups = get_groups()
    pprint(groups)  # show the discovered grouping before merging
    highest_epoch = get_max_epoch(groups)
    for experiment, epoch_parts in groups.items():
        merge_group(experiment, epoch_parts, highest_epoch)
# Run the merge only when executed as a script, not when imported.
if __name__ == '__main__':
    merge()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment