Created
September 21, 2021 09:22
-
-
Save heiner/c82cfef7cc368b02fd831595f3a1435f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
Script for plotting results. | |
``` | |
python plot.py logs.tsv | |
``` | |
""" | |
import argparse | |
import glob | |
import os | |
import gnuplotlib as gp | |
import numpy as np | |
import pandas # Fast CSV reading. | |
def _build_parser():
    """Create the command-line interface for the plotting script.

    Options select which columns to plot (--xkey/--ykey), how to smooth the
    y values (--smoothing/--window), the terminal plot size (--width/--height)
    and whether to request error bars; positional arguments are input files
    or glob patterns.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--xkey", default="step", type=str, help="x values to plot.")
    cli.add_argument("--ykey", default="episode_return", type=str, help="y values to plot.")
    cli.add_argument("--window", default=50, type=int, help="Smoothing window size.")
    cli.add_argument("--width", default=80, type=int, help="Width of plot.")
    cli.add_argument("--height", default=30, type=int, help="Height of plot.")
    cli.add_argument("--errorbars", action="store_true", help="Whether to print error bars.")
    cli.add_argument(
        "--smoothing",
        default="pandas",
        choices=["pandas", "convolve", "cumsum"],
        help="Smoothing algorithm.",
    )
    cli.add_argument("files", nargs="*", type=str)
    return cli


parser = _build_parser()
def moving_average_cumsum(a, n=20):
    """Return the trailing mean of ``a`` over every full window of ``n`` samples.

    Uses running sums: mathematically, sum(a[i-n+1 : i+1]) == S[i] - S[i-n]
    where S is the cumulative sum, so the cost does not depend on ``n``.
    The result has ``len(a) - n + 1`` entries (empty if ``a`` is shorter
    than ``n``). Fast, but not NaN-safe: a single NaN turns every later
    running sum, and therefore every later window, into NaN.
    """
    running = np.cumsum(a, dtype=float)
    # Right-hand side is evaluated into a new array before being written
    # back, so the overlapping slices do not interfere.
    earlier = running[:-n]
    running[n:] = running[n:] - earlier
    # Index n - 1 is left untouched: it already holds the first window's sum.
    first_full_window = n - 1
    return running[first_full_window:] / n
def moving_average(a, n=20):
    """Return the simple moving average of ``a`` over windows of ``n`` samples.

    Uses ``np.convolve(..., mode="valid")``, so only full windows are kept:
    the result has ``len(a) - n + 1`` entries, aligned with ``a[n - 1:]``.
    NaNs propagate into every window that contains them.

    Args:
        a: 1-D sequence of numbers.
        n: Window length; must be a positive integer.

    Returns:
        1-D array of window means; empty when ``a`` has fewer than ``n``
        samples (same convention as ``moving_average_cumsum``).

    Raises:
        ValueError: if ``n`` is not positive.
    """
    if n < 1:
        raise ValueError("window size must be positive, got %r" % (n,))
    a = np.atleast_1d(a)
    if len(a) < n:
        # In "valid" mode np.convolve returns max(M, N) - min(M, N) + 1
        # values, i.e. it silently swaps data and kernel when the kernel is
        # longer. That would yield bogus averages that don't line up with
        # xs[n - 1:]. No full window means no output.
        return np.empty(0)
    return np.convolve(a, np.ones((n,)) / n, mode="valid")
def rolling_xs_ys(xs, ys, window_size=20, smoothing=None):
    """Alternative to rolling() in pandas: trailing moving average of ``ys``.

    Only full windows are kept, so the first ``window_size - 1`` points are
    dropped from ``xs`` to match the smoothed ``ys``.

    Args:
        xs: Sliceable x values; assumed to have the same length as ``ys``.
        ys: y values to smooth.
        window_size: Samples per window; must be positive.
        smoothing: "cumsum" selects moving_average_cumsum; any other value
            selects moving_average (np.convolve). Defaults to
            ``FLAGS.smoothing`` when None, as before.

    Returns:
        ``(xs_trimmed, smoothed_ys)``; both are empty when ``ys`` has fewer
        than ``window_size`` samples.

    Raises:
        ValueError: if ``window_size`` is not positive.
    """
    if window_size < 1:
        # window_size=0 would make xs[-1:] keep one point while the
        # averaging helpers fail or return nothing.
        raise ValueError("window size must be positive, got %r" % (window_size,))
    if smoothing is None:
        smoothing = FLAGS.smoothing
    if len(ys) < window_size:
        # No full window: return empty arrays so x/y lengths always agree.
        return xs[:0], np.empty(0)
    ma = moving_average_cumsum if smoothing == "cumsum" else moving_average
    return xs[window_size - 1 :], ma(ys, window_size)
def plot(xys, xrange=None, yrange=None, color="green"):
    """Draw all curves in one ANSI text plot in the terminal via gnuplotlib.

    Plot size comes from FLAGS.width/FLAGS.height, the title from FLAGS.ykey
    and the x label from FLAGS.xkey; the legend is placed below the plot.

    Args:
        xys: Sequence of curve tuples, each passed to ``gp.plot`` as its own
            positional argument (load_file builds ``(xs, ys, {"legend": ...})``).
        xrange: Optional x range forwarded to gnuplotlib; omitted when None.
        yrange: Optional y range forwarded to gnuplotlib; omitted when None.
        color: Unused. NOTE(review): it was only referenced by a
            commented-out "points linecolor" style.
    """
    opts = {
        "terminal": "dumb %d %d ansi" % (FLAGS.width, FLAGS.height),
        "title": FLAGS.ykey,
        "xlabel": FLAGS.xkey,
        "set": ("key outside bottom center",),
    }
    if FLAGS.errorbars:
        # NOTE(review): "yerrorbars" expects an error column per curve;
        # load_file only yields (xs, ys, options) — confirm this path works.
        opts["with"] = "yerrorbars"
    for key, value in (("xrange", xrange), ("yrange", yrange)):
        if value is not None:
            opts[key] = value
    gp.plot(*xys, **opts)
def load_file(filename):
    """Load one log file and return a gnuplotlib curve tuple.

    The delimiter is chosen from the file extension, case-insensitively:
    ".tsv" means tab-separated, ".csv" means comma-separated. Column
    FLAGS.xkey gives the x values. Column FLAGS.ykey is smoothed over
    FLAGS.window samples, either with pandas' rolling mean
    (FLAGS.smoothing == "pandas") or with rolling_xs_ys, which also trims xs.

    Args:
        filename: Path to a .csv or .tsv file with a header row.

    Returns:
        ``(xs, ys, {"legend": filename})``, ready to be unpacked into gp.plot.

    Raises:
        RuntimeError: if the extension is neither .csv nor .tsv.
    """
    delimiters = {".tsv": "\t", ".csv": ","}
    ext = os.path.splitext(filename)[1].lower()
    if ext not in delimiters:
        # Report the whole filename: the extension alone may be empty.
        raise RuntimeError(
            "Filetype not recognised (expected .csv or .tsv): %r" % (filename,)
        )
    df = pandas.read_csv(filename, sep=delimiters[ext])
    xs = np.array(df[FLAGS.xkey])
    if FLAGS.smoothing == "pandas":
        # min_periods=0 gives a value for the leading rows too (mean of the
        # samples seen so far), so xs and ys keep the same length.
        window = df[FLAGS.ykey].rolling(window=FLAGS.window, min_periods=0)
        ys = np.array(window.mean())
    else:
        ys = np.array(df[FLAGS.ykey])
        xs, ys = rolling_xs_ys(xs, ys, window_size=FLAGS.window)
    return (xs, ys, {"legend": filename})
def main():
    """Plot every file matched by the FLAGS.files glob patterns in one chart.

    Matches for each pattern are sorted, so curve and legend order is
    reproducible. If nothing matches, the script exits through
    ``parser.error`` (usage message, exit status 2) instead of plotting an
    empty set of curves.
    """
    filenames = []
    for pattern in FLAGS.files:
        # glob.glob's result order depends on the filesystem; sort it.
        filenames.extend(sorted(glob.glob(pattern)))
    if not filenames:
        given = " ".join(FLAGS.files) or "(none given)"
        parser.error("no input files matched: %s" % given)
    plot([load_file(filename) for filename in filenames])
if __name__ == "__main__":
    # This already runs at module scope, so plain assignment defines the
    # FLAGS global the functions above read; a `global` statement here would
    # be a no-op. parse_intermixed_args lets file names and options be
    # interleaved on the command line.
    FLAGS = parser.parse_intermixed_args()
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment