Skip to content

Instantly share code, notes, and snippets.

@zmonoid
Created December 14, 2018 11:29
Show Gist options
  • Save zmonoid/e83d31a5f68ac1b74fa7c769b71ec13b to your computer and use it in GitHub Desktop.
Save zmonoid/e83d31a5f68ac1b74fa7c769b71ec13b to your computer and use it in GitHub Desktop.
MatPlotLib Meter Logger
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
class ValueMeter:
    """Accumulate a series of scalar values and expose simple statistics."""

    def __init__(self):
        self.reset()

    def add(self, value):
        """Append ``value`` to the series."""
        self.values.append(value)

    @property
    def value(self):
        """Most recently added value; raises IndexError if nothing was added."""
        return self.values[-1]

    @property
    def mean(self):
        """Arithmetic mean of all values, or NaN when the series is empty."""
        # Return NaN directly instead of letting np.mean emit a
        # "Mean of empty slice" RuntimeWarning for the same result.
        if not self.values:
            return float("nan")
        return np.mean(self.values)

    @property
    def std(self):
        """Population standard deviation (numpy's default ddof=0), or NaN when empty."""
        if not self.values:
            return float("nan")
        return np.std(self.values)

    def reset(self):
        """Discard all recorded values."""
        self.values = []
class PLTMeterLogger:
    """Collect named scalar series, then plot, save, or print them.

    Each meter name maps to a ``ValueMeter`` that accumulates every value
    passed to :meth:`add` under that name.
    """

    def __init__(self, name="Main", usetex=True, font_size=32):
        """Create an empty logger and configure matplotlib's text rendering.

        Args:
            name: Base file name used by :meth:`save` (``<name>.npz``).
            usetex: Render all text through LaTeX (the previous hard-coded
                behaviour). Pass False on machines without a LaTeX
                installation, where matplotlib cannot render usetex text.
            font_size: Global serif font size in points.

        Note: ``plt.rc`` mutates matplotlib's process-wide rcParams, so these
        settings also apply to any other figure drawn afterwards.
        """
        self.meters = {}
        self.name = name
        plt.rc('text', usetex=usetex)
        plt.rc('font', size=font_size, family='serif')

    def add(self, meter, value):
        """Append ``value`` to the series named ``meter``, creating it if needed.

        A ``torch.Tensor`` is converted with ``.item()``, so it must contain
        exactly one element.
        """
        if isinstance(value, torch.Tensor):
            value = value.item()
        if meter not in self.meters:
            self.meters[meter] = ValueMeter()
        self.meters[meter].add(value)

    def plot(self, path='./'):
        """Save one ggplot-styled line plot per meter as ``<path>/<meter>.png``.

        The meter name is used both as the figure title and as the file stem.
        """
        with plt.style.context('ggplot'):
            for k, v in self.meters.items():
                fig, ax = plt.subplots(figsize=(8, 6))
                try:
                    ax.set_title(k)
                    ax.plot(v.values)
                    fig.savefig(os.path.join(path, '{}.png'.format(k)))
                finally:
                    # Close (not merely clear) the figure so repeated calls do
                    # not accumulate open figures, even if rendering fails.
                    plt.close(fig)

    def save(self, path='./', hyperparams=None):
        """Write all meter series plus ``hyperparams`` to ``<path>/<name>.npz``.

        The ``{meter: values}`` dict is passed to ``np.savez`` positionally,
        so it is stored as a 0-d object array under the key ``arr_0``.
        Reading it back needs
        ``np.load(file, allow_pickle=True)["arr_0"].item()``.
        Each ``hyperparams`` entry is stored as its own array under its key.
        For that reason the keys ``arr_0`` and ``file`` conflict with
        ``np.savez`` and must not be used.
        """
        if hyperparams is None:
            hyperparams = {}
        series = {k: v.values for k, v in self.meters.items()}
        np.savez(os.path.join(path, '{}.npz'.format(self.name)), series, **hyperparams)

    def toAcc(self, output, target):
        """Return top-1 accuracy in percent.

        Predictions are ``output.argmax(dim=1)``. The number that equal
        ``target`` is divided by ``output.size(0)``. This presumably expects
        ``output`` of shape (batch, classes) and ``target`` holding class
        indices; TODO confirm with callers.
        """
        correct = output.argmax(dim=1).eq(target).sum().item()
        return correct * 100.0 / output.size(0)

    def reset(self):
        """Drop every meter and its recorded values."""
        self.meters = {}

    def print(self, start=""):
        """Print ``start`` followed by each meter's mean on a single line.

        Only the last space-separated word of a meter name is shown, so a
        meter ``"train loss"`` is printed as ``loss``.
        """
        fields = "".join(
            "{}: {:6.3f}\t".format(k.split(" ")[-1], v.mean)
            for k, v in self.meters.items()
        )
        # Inside the method body this resolves to the builtin print, not
        # to this method (class attributes are not in a method's scope).
        print(start + fields)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment