Created
October 22, 2019 21:56
-
-
Save wayofnumbers/e2f21c356b6fd579ffbd86dc8639e0a6 to your computer and use it in GitHub Desktop.
FMNIST-RunManager
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Helper class that tracks loss, accuracy, epoch time and run time across a
# set of hyper-parameter runs. Metrics are logged to TensorBoard and the
# accumulated results can be exported to csv/json via save().
class RunManager():
    def __init__(self):
        # per-epoch tracking: epoch index, accumulated loss,
        # number of correct predictions, and epoch start time
        self.epoch_count = 0
        self.epoch_loss = 0
        self.epoch_num_correct = 0
        self.epoch_start_time = None

        # per-run tracking: hyper-params used, run index,
        # collected result rows, and run start time
        self.run_params = None
        self.run_count = 0
        self.run_data = []
        self.run_start_time = None

        # model, data loader and TensorBoard writer of the current run
        self.network = None
        self.loader = None
        self.tb = None

    def begin_run(self, run, network, loader):
        """Start a new run.

        Records the hyper-params, model and loader, opens a TensorBoard
        writer tagged with the run's params, and logs a sample image grid
        plus the network graph.
        """
        self.run_start_time = time.time()
        self.run_params = run
        self.run_count += 1

        self.network = network
        self.loader = loader
        self.tb = SummaryWriter(comment=f'-{run}')

        # log one batch of sample images and the model graph
        images, labels = next(iter(self.loader))
        grid = torchvision.utils.make_grid(images)
        self.tb.add_image('images', grid)
        self.tb.add_graph(self.network, images)

    def end_run(self):
        """Close the TensorBoard writer and reset the epoch counter."""
        self.tb.close()
        self.epoch_count = 0

    def begin_epoch(self):
        """Start a new epoch: bump the counter and zero the accumulators."""
        self.epoch_start_time = time.time()
        self.epoch_count += 1
        self.epoch_loss = 0
        self.epoch_num_correct = 0

    def end_epoch(self):
        """Finish the current epoch.

        Computes epoch metrics, logs scalars and parameter/gradient
        histograms to TensorBoard, appends one result row to run_data and
        re-displays the accumulated results table.
        """
        # epoch duration and run duration (accumulated across epochs)
        epoch_duration = time.time() - self.epoch_start_time
        run_duration = time.time() - self.run_start_time

        # average loss / accuracy over the whole dataset
        loss = self.epoch_loss / len(self.loader.dataset)
        accuracy = self.epoch_num_correct / len(self.loader.dataset)

        self.tb.add_scalar('Loss', loss, self.epoch_count)
        self.tb.add_scalar('Accuracy', accuracy, self.epoch_count)

        # histograms of parameters and gradients; param.grad is None for
        # frozen parameters (or before the first backward), so guard it
        for name, param in self.network.named_parameters():
            self.tb.add_histogram(name, param, self.epoch_count)
            if param.grad is not None:
                self.tb.add_histogram(f'{name}.grad', param.grad, self.epoch_count)

        # collect all run-related data for this epoch into one row
        results = OrderedDict()
        results["run"] = self.run_count
        results["epoch"] = self.epoch_count
        results["loss"] = loss
        results["accuracy"] = accuracy
        results["epoch duration"] = epoch_duration
        results["run duration"] = run_duration
        # append the hyper-params used for this run
        for k, v in self.run_params._asdict().items():
            results[k] = v
        self.run_data.append(results)

        # display epoch information and show progress as a growing table
        df = pd.DataFrame.from_dict(self.run_data, orient='columns')
        clear_output(wait=True)
        display(df)

    def track_loss(self, loss, batch_size=None):
        """Accumulate a batch loss into the epoch loss.

        `loss` is assumed to be the batch-mean loss, so it is scaled by the
        batch size to make runs with different batch sizes comparable.
        Pass `batch_size` explicitly (e.g. `images.size(0)`) to be exact on
        a final, smaller batch; when omitted, the loader's nominal
        batch_size is used (original behaviour).
        """
        if batch_size is None:
            batch_size = self.loader.batch_size
        self.epoch_loss += loss.item() * batch_size

    def track_num_correct(self, preds, labels):
        """Accumulate the batch's correct-prediction count into the epoch total."""
        self.epoch_num_correct += self._get_num_correct(preds, labels)

    @torch.no_grad()
    def _get_num_correct(self, preds, labels):
        # index of the max logit per sample, compared against the label
        return preds.argmax(dim=1).eq(labels).sum().item()

    def save(self, fileName):
        """Save all collected run rows to `<fileName>.csv` and `<fileName>.json`."""
        pd.DataFrame.from_dict(
            self.run_data,
            orient='columns',
        ).to_csv(f'{fileName}.csv')

        with open(f'{fileName}.json', 'w', encoding='utf-8') as f:
            json.dump(self.run_data, f, ensure_ascii=False, indent=4)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.