Skip to content

Instantly share code, notes, and snippets.

@wayofnumbers
Created October 22, 2019 21:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wayofnumbers/e2f21c356b6fd579ffbd86dc8639e0a6 to your computer and use it in GitHub Desktop.
Save wayofnumbers/e2f21c356b6fd579ffbd86dc8639e0a6 to your computer and use it in GitHub Desktop.
FMNIST-RunManager
# Helper class: tracks loss, accuracy, epoch time, run time and
# hyper-parameters of each run/epoch, logs them to TensorBoard, and
# can save the collected results to CSV and JSON.
class RunManager:
    """Track per-epoch and per-run training statistics.

    Intended call sequence for each run::

        begin_run(run, network, loader)
        for each epoch:
            begin_epoch()
            for each batch:
                track_loss(loss[, batch_size]); track_num_correct(preds, labels)
            end_epoch()
        end_run()

    After all runs, ``save(fileName)`` writes every recorded epoch row.
    """

    def __init__(self):
        # Per-epoch state: epoch counter within the current run, summed
        # (per-sample) loss, number of correct predictions, start time.
        self.epoch_count = 0
        self.epoch_loss = 0
        self.epoch_num_correct = 0
        self.epoch_start_time = None

        # Per-run state: hyper-parameters of the current run, number of runs
        # started so far, one results row per finished epoch, start time.
        self.run_params = None
        self.run_count = 0
        self.run_data = []
        self.run_start_time = None

        # Model, data loader and TensorBoard writer of the current run.
        self.network = None
        self.loader = None
        self.tb = None

    def begin_run(self, run, network, loader):
        """Start a new run and log sample images plus the network graph.

        Args:
            run: hyper-parameters of this run. ``end_epoch`` calls
                ``run._asdict()``, so a namedtuple (or similar) is expected.
            network: the model being trained.
            loader: the data loader; its first batch is logged as an image
                grid and used to trace the network graph.
        """
        self.run_start_time = time.time()
        self.run_params = run
        self.run_count += 1
        self.network = network
        self.loader = loader
        # The comment suffix makes each run's log directory distinguishable.
        self.tb = SummaryWriter(comment=f'-{run}')

        images, _ = next(iter(self.loader))
        grid = torchvision.utils.make_grid(images)
        self.tb.add_image('images', grid)
        # NOTE(review): `images` is passed as loaded (no device transfer);
        # confirm the network lives on the same device as the batch.
        self.tb.add_graph(self.network, images)

    def end_run(self):
        """Close the run's TensorBoard writer and reset the epoch counter."""
        self.tb.close()
        self.epoch_count = 0

    def begin_epoch(self):
        """Start a new epoch: record its start time and reset the accumulators."""
        self.epoch_start_time = time.time()
        self.epoch_count += 1
        self.epoch_loss = 0
        self.epoch_num_correct = 0

    def end_epoch(self):
        """Finish the epoch: log metrics to TensorBoard and record a results row.

        Loss and accuracy are averaged over ``len(self.loader.dataset)``.
        Every parameter, and its gradient when one exists, is logged as a
        histogram. The row (metrics, durations, hyper-parameters) is appended
        to ``self.run_data`` and the whole table is re-displayed in place.
        """
        # Duration of this epoch, and of the run so far (since begin_run).
        epoch_duration = time.time() - self.epoch_start_time
        run_duration = time.time() - self.run_start_time

        # Average the accumulated totals over the number of samples.
        num_samples = len(self.loader.dataset)
        loss = self.epoch_loss / num_samples
        accuracy = self.epoch_num_correct / num_samples

        self.tb.add_scalar('Loss', loss, self.epoch_count)
        self.tb.add_scalar('Accuracy', accuracy, self.epoch_count)

        for name, param in self.network.named_parameters():
            self.tb.add_histogram(name, param, self.epoch_count)
            # Parameters that got no gradient (e.g. frozen ones) have
            # grad None, which add_histogram cannot log.
            if param.grad is not None:
                self.tb.add_histogram(f'{name}.grad', param.grad, self.epoch_count)

        results = OrderedDict()
        results["run"] = self.run_count
        results["epoch"] = self.epoch_count
        results["loss"] = loss
        results["accuracy"] = accuracy
        results["epoch duration"] = epoch_duration
        results["run duration"] = run_duration
        # Hyper-parameters become extra columns of the results table.
        results.update(self.run_params._asdict())
        self.run_data.append(results)

        df = pd.DataFrame.from_dict(self.run_data, orient='columns')
        # Replace the previous notebook output to show progress in place.
        clear_output(wait=True)
        display(df)

    def track_loss(self, loss, batch_size=None):
        """Add one batch's loss to the epoch total.

        ``loss`` is presumably the batch-mean loss (PyTorch's default
        reduction), so it is scaled by the batch size to get a per-sample
        sum. This lets runs with different batch sizes be compared.

        Args:
            loss: scalar loss tensor for the batch.
            batch_size: number of samples in this batch. Defaults to
                ``self.loader.batch_size``. Pass the actual size (e.g.
                ``len(labels)``) so a smaller final batch is not
                over-counted.
        """
        if batch_size is None:
            batch_size = self.loader.batch_size
        self.epoch_loss += loss.item() * batch_size

    def track_num_correct(self, preds, labels):
        """Add the batch's number of correct predictions to the epoch total."""
        self.epoch_num_correct += self._get_num_correct(preds, labels)

    @torch.no_grad()
    def _get_num_correct(self, preds, labels):
        """Count rows of ``preds`` whose argmax (dim 1) equals ``labels``."""
        return preds.argmax(dim=1).eq(labels).sum().item()

    def save(self, fileName):
        """Write all recorded epoch rows to ``{fileName}.csv`` and ``{fileName}.json``."""
        pd.DataFrame.from_dict(
            self.run_data,
            orient='columns',
        ).to_csv(f'{fileName}.csv')

        with open(f'{fileName}.json', 'w', encoding='utf-8') as f:
            json.dump(self.run_data, f, ensure_ascii=False, indent=4)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment