Created
October 3, 2009 17:06
-
-
Save tkf/200756 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
from mpl_toolkits.axes_grid import AxesGrid | |
import numpy | |
class MultiSeqPlotter(object):
    """
    Reusable multi-sequence plotter for plotting some time series
    of some variables. Time series are column-wise and variables are
    row-wise.

    Typical use: build the plotter once, call `plot_all()` to draw, and
    call `cla_all()` before re-plotting with updated `plot_table` data.
    """

    def __init__(self, plot_table,
                 sum_steps=None, fignum=None, plot_opts=None,
                 xlabels=None, ylabels=None, titles=None, suptitle=None,
                 xlims=None, ylims=None,
                 figsizex_per_steps=0.04, figsizey_per_rows=2):
        """
        - `plot_table` : ncols x nrows list of (xarray,yarray) data

        ==================== ======= =======================================
        Optional Keywords    Default Description
        ==================== ======= =======================================
        sum_steps            None    sum of steps in x-axis
        fignum               None    it will be passed to `plt.figure`
        plot_opts            None    extra positional args for `ax.plot`
        xlabels              None    list of xlabel(str)
        ylabels              None    list of ylabel(str)
        titles               None    list of title(str) for each column
        suptitle             None    title(str) for the whole figure
        xlims                None    list of (xmin,xmax)
        ylims                None    list of (ymin,ymax)
        figsizex_per_steps   0.04    figsize[0]=figsizex_per_steps*sum_steps
        figsizey_per_rows    2       figsize[1]=figsizey_per_rows*nrows
        ==================== ======= =======================================

        When `sum_steps` is None it is computed as the sum, over all
        columns, of the length of the x-array in the column's first row.

        `plot_opts` may be a list such as ``['.-']`` or a single format
        string such as ``'.-'``.

        Raises `ValueError` if `plot_table` has no columns or if its
        first column has no rows.
        """
        # Validate before touching matplotlib so that bad input fails with
        # a clear message instead of an IndexError from `plot_table[0]`.
        # `len()` is used (not truthiness) so array-like tables also work.
        if len(plot_table) == 0:
            raise ValueError("plot_table must contain at least one column")
        if len(plot_table[0]) == 0:
            raise ValueError(
                "each column of plot_table must contain at least one "
                "(xarray, yarray) pair")

        # set values
        self.plot_table = plot_table
        self.xlabels = xlabels
        self.ylabels = ylabels
        self.titles = titles
        self.xlims = xlims
        self.ylims = ylims
        self.figsizex_per_steps = figsizex_per_steps
        self.figsizey_per_rows = figsizey_per_rows
        if plot_opts is None:
            self.plot_opts = []
        elif isinstance(plot_opts, str):
            # A bare format string would otherwise be split into single
            # characters by `list(self.plot_opts)` in `plot_all`.
            self.plot_opts = [plot_opts]
        else:
            self.plot_opts = plot_opts
        if sum_steps is None:
            self.sum_steps = sum([len(col[0][0]) for col in plot_table])
        else:
            self.sum_steps = sum_steps

        # get figure
        self.nrows = len(plot_table[0])
        self.ncols = len(plot_table)
        figsize = (self.figsizex_per_steps * self.sum_steps,
                   self.figsizey_per_rows * self.nrows)
        self.fig = plt.figure(fignum, figsize=figsize)
        self.fig.clf()
        if suptitle is not None:
            self.fig.suptitle(suptitle)

        # get AxesGrid
        self.grid = AxesGrid(self.fig, 111,
                             nrows_ncols=(self.nrows, self.ncols),
                             axes_pad=0.0,
                             add_all=True,
                             aspect=False,
                             )
        self.cla_all()

    def cla_all(self):
        "clear all axes and set axes attributes"
        for ax in self.grid:
            ax.cla()
        # set axes attributes
        self.grid.set_label_mode('L')
        # y labels go on the first column (one axes per row),
        # x labels on the last row and titles on the first row
        # (one axes per column).
        if self.ylabels is not None:
            for (row, lb) in zip(self.grid.axes_column[0], self.ylabels):
                row.set_ylabel(lb)
        if self.xlabels is not None:
            for (col, lb) in zip(self.grid.axes_row[-1], self.xlabels):
                col.set_xlabel(lb)
        if self.titles is not None:
            for (col, tt) in zip(self.grid.axes_row[0], self.titles):
                col.set_title(tt)

    def plot_all(self):
        "plot for all axes and set xlim/ylim"
        # Pair each column of axes with the matching column of data, then
        # each axes in that column with its (xarray, yarray) entry.
        for (axes_col, data_col) in zip(self.grid.axes_column,
                                        self.plot_table):
            for (ax, line) in zip(axes_col, data_col):
                ax.plot(*(list(line) + list(self.plot_opts)))
        # set xlim/ylim
        # NOTE(review): limits are applied only to the last-row /
        # first-column axes; this presumably relies on the grid sharing
        # x within a column and y within a row -- confirm.
        if self.xlims is not None:
            for (col, xlim) in zip(self.grid.axes_row[-1], self.xlims):
                col.set_xlim(*xlim)
        if self.ylims is not None:
            for (row, ylim) in zip(self.grid.axes_column[0], self.ylims):
                row.set_ylim(*ylim)
# Demo: plot three functions (rows) over five sequences of different
# length (columns) with MultiSeqPlotter.
list_func = [lambda x: numpy.sin(x),
             lambda x: numpy.sin(0.1*x),
             lambda x: numpy.sin(x)*numpy.sin(0.1*x)]
list_tmax = [30, 50, 70, 100, 200]
dt = 1

# plot_table[col][row] = (t, f(t)): one column per tmax, one row per function.
plot_table = []
for tmax in list_tmax:
    plot_col = []
    t = numpy.arange(0, tmax, dt)
    for f in list_func:
        plot_col.append((t, f(t)))
    plot_table.append(plot_col)

xlabels = ['time t'] * len(list_tmax)
ylabels = ['sin(t)', 'sin(0.1*t)', 'sin(t) * sin(0.1*t)']
# `range` instead of `xrange`: works on both Python 2 and Python 3.
titles = ['seq = %d' % i for i in range(len(list_tmax))]
ylims = [(-1.1, 1.1)] * len(list_func)

msp = MultiSeqPlotter(plot_table, fignum=1, plot_opts=['.-'],
                      xlabels=xlabels, ylabels=ylabels, titles=titles,
                      suptitle='Big Title', ylims=ylims)
msp.plot_all()
plt.draw()
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment