Skip to content

Instantly share code, notes, and snippets.

@xiangze
Last active August 29, 2015 13:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xiangze/9579536 to your computer and use it in GitHub Desktop.
Save xiangze/9579536 to your computer and use it in GitHub Desktop.
Time-series visualization classes for PyStan samples (see http://xiangze.hatenablog.com/entry/2014/03/16/155826)
import numpy as np
class stanTSdraw:
    """Dump summaries and per-element posterior quantiles of a PyStan fit to log files.

    All files are written to the current directory and carry ``name`` as a
    suffix. ``dumpquantile`` writes one comma-separated ``[upper, middle, lower]``
    line per element, the format that ``stanTSdrawFromfile.draw`` splits on.
    """

    def __init__(self, name, fit):
        """
        Args:
            name: run name used as the suffix of every log file name.
            fit: the fit object to dump; its ``summary``, ``get_posterior_mean``
                and ``extract`` methods are used.
        """
        import pystan  # NOTE(review): unused here; presumably a fail-fast dependency check
        self.fname = name
        self.fit = fit

    def dumpsummary(self):
        """Write ``str(fit.summary())`` to ``summary_<name>.log``."""
        with open("summary_" + self.fname + ".log", 'w') as f:
            f.write(str(self.fit.summary()) + "\n")

    def dumpmean(self):
        """Write ``str(fit.get_posterior_mean())`` to ``meanv_<name>.log``."""
        with open("meanv_" + self.fname + ".log", 'w') as f:
            f.write(str(self.fit.get_posterior_mean()) + "\n")

    def dumpquantile(self, qtop=97.5, qbuttom=2.5, qmean=50):
        """Write quantiles of every extracted variable, one line per element.

        For each variable ``v`` in ``fit.extract()``, writes
        ``<v>_quantile_<qtop>_<qmean>_<qbuttom>_<name>.log`` containing lines
        of the form ``[upper, middle, lower]``.

        Args:
            qtop, qbuttom, qmean: quantile levels, either as probabilities in
                [0, 1] or -- when any of them exceeds 1, as with the defaults --
                as percentages. They appear verbatim in the file name.
        """
        from scipy.stats.mstats import mquantiles

        qq = str(qtop) + "_" + str(qmean) + "_" + str(qbuttom)
        # mquantiles needs probabilities in [0, 1]; out-of-range levels are
        # clipped and would all yield the sample maximum. Rescale percentages.
        scale = 100.0 if max(qtop, qbuttom, qmean) > 1 else 1.0
        probs = [qtop / scale, qmean / scale, qbuttom / scale]
        # Extract once rather than re-extracting the whole fit per variable.
        samples = self.fit.extract()
        for v, ff in samples.items():
            # Assuming draws lie along axis 0 (PyStan's permuted extract layout),
            # zip(*ff) walks the next axis so each line summarizes one element.
            # NOTE(review): for 1-D variables each line comes from a single draw,
            # so all three values equal that draw.
            result = zip(*ff) if ff.ndim > 1 else ff
            with open(v + "_" + "quantile_" + qq + "_" + self.fname + ".log", 'w') as f:
                for s in result:
                    q = mquantiles(s, probs)
                    # Comma-separated so the file reader's comma split works.
                    f.write("[" + ", ".join(repr(float(x)) for x in q) + "]\n")
class stanTSdrawFromfile:
    """Plot quantile time series read back from ``*_quantile_*.log`` files.

    Each non-blank line of a log file holds three numbers -- upper quantile,
    middle quantile, lower quantile -- either comma-separated (``[a, b, c]``)
    or whitespace-separated as numpy prints arrays (``[ a  b  c]``).
    """

    def __init__(self, name, dir=""):
        """
        Args:
            name: run name; the file-name suffix used when the logs were dumped.
            dir: directory holding the log files ("" means the current directory).
        """
        self.fname = name
        self.dir = dir

    def draw(self, v="trend", qq="0.975_0.5_0.25", minlen=1, fromN=1973, toN=2013,
             steps_per_year=12):
        """Plot the quantile band and middle line of variable *v* to a PNG.

        Reads ``<dir>/<v>_quantile_<qq>_<name>.log`` and saves the figure as
        ``<dir>/<v>_quantile_<qq>_<name>.png``. Row ``n`` is placed at
        ``fromN + n / steps_per_year`` on the x axis. I/O errors are reported
        on stdout instead of being raised.

        Args:
            v: variable name used in the log file name.
            qq: quantile tag used in the log file name.
            minlen: minimum number of rows needed to plot; shorter (or empty)
                series are skipped with a message.
            fromN, toN: first / last year; fromN offsets the x axis and both
                appear in the title.
            steps_per_year: rows per x-axis unit (12 = monthly data).

        Raises:
            ValueError: if a line contains a token that is not a number.
        """
        import os

        # os.path.join keeps the path relative when dir == "" (plain string
        # concatenation produced "/<file>", i.e. a path at the filesystem root).
        filename = os.path.join(self.dir, v + "_" + "quantile_" + qq + "_" + self.fname)
        try:
            rows = []
            with open(filename + ".log") as f:
                for line in f:
                    # Brackets are decoration; commas and whitespace both separate.
                    tokens = line.replace("[", " ").replace("]", " ").replace(",", " ").split()
                    if tokens:  # skip blank lines such as a trailing newline
                        rows.append([float(t) for t in tokens])
            if not rows or len(rows) < minlen:
                print("The time series is too short.")
                return
            # Import pyplot only once there is something to draw.
            import matplotlib.pyplot as plt

            x = [n / float(steps_per_year) + fromN for n in range(len(rows))]
            columns = list(zip(*rows))
            upper, middle, lower = columns[0], columns[1], columns[2]
            fig, a = plt.subplots(1, sharex=True)
            try:
                a.fill_between(x, upper, lower, color='blue', alpha=0.5)
                a.set_title(v + " " + str(fromN) + "-" + str(toN))
                a.plot(x, middle, lw=2, label=self.fname, color='black')
                fig.savefig(filename + '.png')
            finally:
                # draw() is called in a loop; close each figure so pyplot does
                # not keep every one of them open.
                plt.close(fig)
        except IOError as e:
            print("I/O error({0}): {1} {2}".format(e.errno, e.strerror, filename + ".log"))
if __name__ == '__main__':
    import glob
    import os

    # Every "summary_<model>.log" in the log directory names a model whose
    # quantile logs are then plotted.
    logdir = "log4"
    prefix = "summary_"
    suffix = ".log"
    fromN = 1973
    toN = 2013
    variables = ['trend', 'mm', 'ar', 'c_ar', 's_trend', 's_ar', 's_mm', 's_tot']
    for path in glob.glob(os.path.join(logdir, prefix + "kion_hensa_model*" + suffix)):
        # Slice off the exact prefix/suffix of the basename; str.replace on the
        # full path broke with Windows separators and also removed inner ".log".
        model = os.path.basename(path)[len(prefix):-len(suffix)]
        s = stanTSdrawFromfile(model, dir=logdir)
        for v in variables:
            s.draw(v=v, minlen=10, fromN=fromN, toN=toN)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment