Skip to content

Instantly share code, notes, and snippets.

@cxy1997
Created October 18, 2018 01:06
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cxy1997/7caf714911816b32fed6481b84d545a1 to your computer and use it in GitHub Desktop.
Save cxy1997/7caf714911816b32fed6481b84d545a1 to your computer and use it in GitHub Desktop.
Plot curves with matplotlib
# -*- coding: utf-8 -*-
from __future__ import division, print_function
import numpy as np
import copy
import matplotlib.pyplot as plt
import seaborn
def smooth(array, m=3):
    """Smooth a series with a centered moving average.

    Each element ``i >= 1`` is replaced by the mean of the window
    ``array[i - m : i + m + 1]`` (clipped to the array bounds). The standard
    deviation of that same window is recorded alongside it. Windows are always
    read from the *input* values, so smoothing does not cascade from one point
    to the next.

    Args:
        array: sequence or ndarray of values, windowed along axis 0.
        m: half-width of the window; each window spans up to ``2*m + 1``
            elements.

    Returns:
        tuple ``(smoothed, std)`` of arrays shaped like the input. Integer or
        bool input is promoted to float so that fractional means and
        deviations are not truncated. Float/complex input keeps its dtype.
        The input itself is never modified.
    """
    values = np.asarray(array)
    # Copying the input dtype verbatim would store float means/stds into an
    # integer array, silently truncating them; promote to float instead.
    if not np.issubdtype(values.dtype, np.inexact):
        values = values.astype(float)
    smoothed = values.copy()
    std = np.zeros_like(values)
    n = values.shape[0]
    # The loop starts at 1, so element 0 is passed through unchanged with a
    # std of 0 (same as before).
    # NOTE(review): confirm that anchoring the first point is intended.
    for i in range(1, n):
        window = values[max(0, i - m): min(n, i + m + 1)]
        smoothed[i] = window.mean()
        std[i] = window.std()
    return smoothed, std
def cut(array, m1=300, m2=200, l=30):
    """Return the mean of a trimmed slice of ``array`` along axis 0.

    Up to ``m1`` leading and ``m2`` trailing elements are dropped. The trim is
    limited so that roughly the central ``l`` elements (or the whole array,
    when it is shorter than ``l``) are always kept. An empty slice yields NaN,
    following numpy's ``mean`` of an empty array.

    Args:
        array: ndarray to average.
        m1: maximum number of leading elements to discard.
        m2: maximum number of trailing elements to discard.
        l: approximate number of central elements that are never discarded.

    Returns:
        The mean of the retained slice.
    """
    total = array.shape[0]
    # Bounds of the central core that must survive trimming.
    core_lo = int((total - l) / 2)
    core_hi = int((total + l) / 2)
    # Trim at most m1 from the front / m2 from the back, never into the core,
    # and never past the array edges.
    lo = max(min(core_lo, m1), 0)
    hi = min(max(core_hi, total - m2), total)
    window = array[lo:hi]
    return window.mean()
def init_plot():
    """Apply the shared figure style and open a new figure.

    The style is matplotlib's dark-grid seaborn style. The font family is
    'serif', with the serif face set to 'Ubuntu' and the monospace face to
    'Ubuntu Mono'. All text elements use a uniform 20 pt size.

    Returns:
        matplotlib.figure.Figure: a new figure with a 0.68 height/width
        aspect ratio at 200 dpi; it becomes the current pyplot figure.
    """
    # Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*'
    # and 3.8 removed the old names, so 'seaborn-darkgrid' alone raises on
    # current releases. Prefer the new name; fall back for older versions.
    style = 'seaborn-v0_8-darkgrid'
    if style not in plt.style.available:
        style = 'seaborn-darkgrid'
    plt.style.use(style)
    # Applied after the style so these values override the style's own.
    # Add 'axes.labelweight': 'bold' here for bold axis labels.
    plt.rcParams.update({
        'font.family': 'serif',
        'font.serif': 'Ubuntu',
        'font.monospace': 'Ubuntu Mono',
        'font.size': 20,
        'axes.labelsize': 20,
        'axes.titlesize': 20,
        'xtick.labelsize': 20,
        'ytick.labelsize': 20,
        'legend.fontsize': 20,
        'figure.titlesize': 20,
    })
    width, height = plt.figaspect(0.68)
    return plt.figure(figsize=(width, height), dpi=200)
if __name__ == '__main__':
    # Demo/template: build a styled figure, label it and save it to disk.
    # Draw the real data on ``ax`` before the legend call, e.g.
    # ``ax.plot(steps, rewards)``.
    fig = init_plot()
    ax = fig.gca()
    ax.set_xlabel(r'Steps / $10^5$')
    ax.set_ylabel(r'Reward / $10^2$')
    ax.legend(['model_name'])
    fig.tight_layout()
    fig.savefig('reward.png', dpi=200)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment