Skip to content

Instantly share code, notes, and snippets.

@JohannesBuchner
Last active April 21, 2022 19:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save JohannesBuchner/7bc1728e92ce46ef4e477a5d10085024 to your computer and use it in GitHub Desktop.
Save JohannesBuchner/7bc1728e92ce46ef4e477a5d10085024 to your computer and use it in GitHub Desktop.
Intercept all matplotlib calls and store each figure's data into json files, with labels as keys.
"""
Intercept all matplotlib calls and store each figure's data
into json files, with labels as keys.
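Usage:
Replace
    import matplotlib.pyplot as plt
with
    from mplrecorder import plt

Each plt.savefig() call given a file name additionally writes
<fname>.log, a Python script of the plt calls recorded since the last
plt.close(), and <fname>.data, those calls' arguments as JSON grouped
by figure/axes title.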
"""
import json
import os
import sys

import matplotlib.pyplot as realplt
import numpy as np
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also accepts numpy values.

    Arrays become (nested) lists and numpy scalars become the matching
    Python scalar. Any other object the base encoder rejects is written as
    the placeholder string ``'<not serializable>'`` instead of aborting the
    whole dump.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # numpy scalars such as np.int64 / np.float32 / np.bool_ are not
            # Python ints/floats/bools, so json rejects them; unwrap them.
            return obj.item()
        try:
            return json.JSONEncoder.default(self, obj)
        except TypeError:
            # The base implementation signals "unknown type" with TypeError;
            # catching only that lets KeyboardInterrupt etc. propagate.
            return '<not serializable>'
class CallIntercepter():
    """Wrap one pyplot callable and record every call made through it.

    Each call is appended to ``log`` as a line of Python source
    (``plt.<name>(*<args>, **<kwargs>)``). For most functions the call's
    arguments are also stored in ``data`` under the current title. When the
    wrapped function is ``savefig`` and its target is a file name,
    ``<fname>.log`` (the recorded calls as a script) and ``<fname>.data``
    (``data`` as JSON) are written before the real ``savefig`` runs.

    Parameters
    ----------
    origfunction : callable
        The real function being wrapped.
    name : str
        The attribute name under which ``origfunction`` was looked up.
    log : list of str, optional
        List receiving the recorded source lines; may be shared between
        intercepters. A fresh list is used when omitted.
    data : dict, optional
        Dict receiving the recorded arguments; may be shared between
        intercepters. The key ``'current_title'`` holds the title that
        subsequent calls are grouped under. Every other key is a title
        mapping ``label_or_function_name -> (name, args, kwargs)``.
        A fresh dict is used when omitted.
    """

    def __init__(self, origfunction, name, log=None, data=None):
        self.origfunction = origfunction
        self.name = name
        # Fresh containers per instance: mutable default arguments would be
        # shared by every intercepter created without explicit ones.
        self.log = [] if log is None else log
        self.data = {} if data is None else data

    def __call__(self, *args, **kwargs):
        """Record the call, write the side files on savefig, then forward it."""
        self._record(args, kwargs)
        if self.name == 'savefig':
            self._dump(args, kwargs)
        return self.origfunction(*args, **kwargs)

    def _record(self, args, kwargs):
        """Update ``data`` (title tracking or stored arguments) and ``log``."""
        if self.name == 'figure':
            self.log.append('')
            self.log.append('')
            # use named figure as title
            if args and isinstance(args[0], str):
                self.data['current_title'] = args[0]
            else:
                self.data.pop('current_title', None)
        elif self.name == 'title':
            # use title, if provided, as a key; also accept title(label=...)
            label = args[0] if args else kwargs.get('label', '')
            self.data['current_title'] = str(label)
        elif self.name == 'subplot':
            self.log.append('')
            # use subplot arguments (which panel we are working on)
            # as the fiducial title
            self.data['current_title'] = str(args)
        else:
            title = self.data.get('current_title', '')
            curdata = self.data.get(title, {})
            # NOTE: under one title, calls sharing a label (or unlabeled
            # calls of the same function) overwrite each other here.
            curdata[kwargs.get('label', self.name)] = (self.name, args, kwargs)
            self.data[title] = curdata
        # Print arrays in full so the .log file stays a runnable script;
        # by default numpy abbreviates long arrays with '...'.
        with np.printoptions(threshold=sys.maxsize):
            self.log.append('plt.%s(*%s, **%s)' % (self.name, args, kwargs))

    def _dump(self, args, kwargs):
        """Write ``<fname>.log`` and ``<fname>.data`` for a savefig call.

        Skipped when the savefig target is not a path (e.g. a file object),
        because there is then no file name to derive the side-file names from.
        """
        fname = args[0] if args else kwargs.get('fname')
        if not isinstance(fname, (str, os.PathLike)):
            return
        base = os.fspath(fname)
        with open(base + '.log', 'w', encoding='utf-8') as fout:
            fout.write('import matplotlib.pyplot as plt\n')
            fout.write('from numpy import array\n')
            fout.write('\n')
            fout.write('\n'.join(self.log))
            fout.write('\n')
        with open(base + '.data', 'w', encoding='utf-8') as fout:
            json.dump(self.data, fout, indent=4, cls=NumpyEncoder)

    def __getattr__(self, name):
        """Forward unknown attributes (e.g. ``__doc__``-like metadata) to the wrapped function."""
        # Guard against infinite recursion when 'origfunction' is not set
        # yet (e.g. copy/unpickling, which bypass __init__).
        if name == 'origfunction':
            raise AttributeError(name)
        return getattr(self.origfunction, name)
class MplRecorder(object):
    """Store matplotlib calls and data.

    Stands in for ``matplotlib.pyplot``. Every attribute looked up on an
    instance is fetched from pyplot. Callables other than classes are
    wrapped in a ``CallIntercepter`` that shares this recorder's ``log``
    and ``data``. Everything else is returned unchanged.
    """

    def __init__(self):
        self.log = []   # recorded source lines, shared with intercepters
        self.data = {}  # recorded call arguments, grouped by title

    def __getattr__(self, name):
        obj = getattr(realplt, name)
        # Pass through non-callables (e.g. rcParams) and classes (e.g. Figure).
        # A wrapper would break item assignment and isinstance() checks.
        if not callable(obj) or isinstance(obj, type):
            return obj
        ret = CallIntercepter(obj, name, log=self.log, data=self.data)
        if name == 'close':
            # Start a fresh recording. The intercepter returned here still
            # holds the previous log/data, so the close() call itself is
            # appended to the old recording.
            # NOTE: the reset happens on attribute lookup, not on the call.
            self.log = []
            self.data = {}
        return ret
# Module-level recorder; import it in place of matplotlib.pyplot.
plt = MplRecorder()

if __name__ == '__main__':
    # Demo: saving foo.pdf also writes foo.pdf.log and foo.pdf.data.
    # (numpy is already available as np from the module-level import.)
    plt.plot([1, 2], np.array([3, 4]), color='g')
    plt.savefig("foo.pdf", bbox_inches='tight')
    plt.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment