Skip to content

Instantly share code, notes, and snippets.

@sglyon
Last active August 29, 2015 14:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sglyon/5428ab8d21726acbd22f to your computer and use it in GitHub Desktop.
cs2014

Code/data for cs2014

To actually run this code you need to get my gibbs branch of MCMC.jl:

Pkg.clone("git@github.com:spencerlyon2/MCMC.jl.git")
Pkg.checkout("MCMC", "gibbs")

You will also need a special version of the StateSpace.jl package:

Pkg.clone("git@github.com:cc7768/StateSpace.jl.git")
Pkg.checkout("StateSpace", "non-constant")
-.06870229007633588 NaN -.06870229007633588
.032786885245901676 NaN .032786885245901676
.023809523809523725 NaN .023809523809523725
.10077519379844957 NaN .10077519379844957
-.176056338028169 NaN -.176056338028169
.008547008547008517 NaN .008547008547008517
.06779661016949157 NaN .06779661016949157
.11904761904761907 NaN .11904761904761907
-.049645390070921946 NaN -.049645390070921946
-.02985074626865669 NaN -.02985074626865669
-.11538461538461542 NaN -.11538461538461542
.13043478260869557 NaN .13043478260869557
.007692307692307665 NaN .007692307692307665
-.03816793893129766 NaN -.03816793893129766
.039682539682539764 NaN .039682539682539764
.23664122137404586 NaN .23664122137404586
.1234567901234569 NaN .1234567901234569
-.06593406593406592 NaN -.06593406593406592
-.11176470588235299 NaN -.11176470588235299
0 NaN 0
-.026490066225165587 NaN -.026490066225165587
-.1496598639455783 NaN -.1496598639455783
-.15200000000000002 NaN -.15200000000000002
-.037735849056603765 NaN -.037735849056603765
.03921568627450989 NaN .03921568627450989
-.028301886792452824 NaN -.028301886792452824
-.04854368932038833 NaN -.04854368932038833
.05102040816326525 NaN .05102040816326525
-.03883495145631066 NaN -.03883495145631066
-.010101010101010055 NaN -.010101010101010055
-.010204081632653073 NaN -.010204081632653073
-.010309278350515427 NaN -.010309278350515427
-.05208333333333337 NaN -.05208333333333337
.03296703296703307 NaN .03296703296703307
.010638297872340496 NaN .010638297872340496
0 NaN 0
-.052631578947368474 NaN -.052631578947368474
.11111111111111116 NaN .11111111111111116
.1399999999999999 NaN .1399999999999999
.00877192982456143 NaN .00877192982456143
-.04347826086956519 NaN -.04347826086956519
.018181818181818077 NaN .018181818181818077
-.1517857142857143 NaN -.1517857142857143
-.03157894736842104 NaN -.03157894736842104
-.10869565217391308 NaN -.10869565217391308
-.08536585365853655 NaN -.08536585365853655
.026666666666666616 NaN .026666666666666616
.07792207792207795 NaN .07792207792207795
0 NaN 0
.0843373493975903 NaN .0843373493975903
-.0888888888888889 NaN -.0888888888888889
0 NaN 0
.024390243902439046 NaN .024390243902439046
-.011904761904761862 NaN -.011904761904761862
.06024096385542177 NaN .06024096385542177
.10227272727272729 NaN .10227272727272729
.11340206185567014 NaN .11340206185567014
.0185185185185186 NaN .0185185185185186
-.045454545454545414 NaN -.045454545454545414
.05714285714285716 NaN .05714285714285716
-.16216216216216217 NaN -.16216216216216217
.021505376344086002 NaN .021505376344086002
-.021052631578947323 NaN -.021052631578947323
-.043010752688172005 NaN -.043010752688172005
.1685393258426966 NaN .1685393258426966
.27884615384615374 NaN .27884615384615374
.45112781954887216 NaN .45112781954887216
-.041450777202072575 NaN -.041450777202072575
-.05945945945945941 NaN -.05945945945945941
-.06896551724137934 NaN -.06896551724137934
-.024691358024691357 NaN -.024691358024691357
-.044303797468354444 NaN -.044303797468354444
-.10596026490066224 NaN -.10596026490066224
-.03703703703703709 NaN -.03703703703703709
.04615384615384621 NaN .04615384615384621
-.022058823529411797 NaN -.022058823529411797
-.052631578947368474 NaN -.052631578947368474
-.06349206349206349 NaN -.06349206349206349
-.06779661016949157 NaN -.06779661016949157
-.036363636363636376 NaN -.036363636363636376
-.14150943396226412 NaN -.14150943396226412
-.01098901098901095 NaN -.01098901098901095
.11111111111111116 NaN .11111111111111116
.030000000000000027 NaN .030000000000000027
.04854368932038833 NaN .04854368932038833
-.06481481481481477 NaN -.06481481481481477
-.07920792079207917 NaN -.07920792079207917
-.08602150537634412 NaN -.08602150537634412
-.03529411764705881 NaN -.03529411764705881
.03658536585365857 NaN .03658536585365857
.0117647058823529 NaN .0117647058823529
-.05813953488372092 NaN -.05813953488372092
.012345679012345734 NaN .012345679012345734
-.010309278350515427 NaN -.010309278350515427
-.0625 NaN -.0625
.022222222222222143 NaN .022222222222222143
-.09782608695652162 NaN -.09782608695652162
.012048192771084265 NaN .012048192771084265
-.04761904761904767 NaN -.04761904761904767
0 NaN 0
.050000000000000044 NaN .050000000000000044
.0714285714285714 NaN .0714285714285714
.07777777777777772 NaN .07777777777777772
-.020618556701030855 NaN -.020618556701030855
.0736842105263158 NaN .0736842105263158
.009803921568627638 NaN .009803921568627638
0 NaN 0
.009708737864077666 NaN .009708737864077666
.028846153846153744 NaN .028846153846153744
.04672897196261672 NaN .04672897196261672
-.02678571428571419 NaN -.02678571428571419
.07339449541284404 NaN .07339449541284404
.03418803418803429 NaN .03418803418803429
-.07438016528925628 NaN -.07438016528925628
.0625 NaN .0625
.008403361344537785 NaN .008403361344537785
-.025000000000000022 NaN -.025000000000000022
.025641025641025772 NaN .025641025641025772
.22499999999999987 NaN .22499999999999987
.38095238095238115 NaN .38095238095238115
.11330049261083741 NaN .11330049261083741
.05752212389380529 NaN .05752212389380529
.11297071129707126 NaN .11297071129707126
-.368421052631579 NaN -.368421052631579
-.005952380952381042 NaN -.005952380952381042
.041916167664670656 NaN .041916167664670656
-.028735632183908066 NaN -.028735632183908066
.059171597633136175 NaN .059171597633136175
-.03351955307262555 NaN -.03351955307262555
-.05202312138728338 NaN -.05202312138728338
.018292682926829285 NaN .018292682926829285
-.017964071856287456 NaN -.017964071856287456
-.0914634146341462 NaN -.0914634146341462
-.1610738255033557 NaN -.1610738255033557
-.10400000000000009 NaN -.10400000000000009
.017857142857143016 NaN .017857142857143016
.13157894736842102 NaN .13157894736842102
.06976744186046524 NaN .06976744186046524
.007246376811594235 NaN .007246376811594235
.07194244604316546 NaN .07194244604316546
-.0872483221476511 NaN -.0872483221476511
-.022058823529411686 NaN -.022058823529411686
.015037593984962294 NaN .015037593984962294
.11111111111111116 NaN .11111111111111116
.1333333333333333 NaN .1333333333333333
.04705882352941182 NaN .04705882352941182
.00561797752808979 NaN .00561797752808979
.016759776536312998 NaN .016759776536312998
.14285714285714302 NaN .14285714285714302
.23076923076923084 NaN .23076923076923084
.08203125 .07057716436637396 .07057716436637396
-.050541516245487306 -.028127746850278412 -.028127746850278412
.03802281368821303 .01899306602351536 .01899306602351536
.11355311355311337 .09437869822485201 .09437869822485201
-.02631578947368407 -.005947553392808835 -.005947553392808835
-.013513513513513598 -.009790590155017709 -.009790590155017709
.0034246575342467 .0021971985718209908 .0021971985718209908
.0034129692832762792 .0021923814743767256 .0021923814743767256
.03061224489795933 .02789171452009831 .02789171452009831
.02970297029702973 .03724394785847296 .03724394785847296
.012820512820512997 .022313413695819184 .022313413695819184
.0031645569620253333 -.002007024586051065 -.002007024586051065
0 .008295625942684737 .008295625942684737
-.003154574132492094 -.00024931438544006923 -.00024931438544006923
.0031645569620253333 .0032418952618455954 .0032418952618455954
-.003154574132492094 -.0034799900571717 -.0034799900571717
0 .003741581441756514 .003741581441756514
.02215189873417711 .01764413518886654 .01764413518886654
.030959752321981338 .03125763125763137 .03125763125763137
.0030030030030030463 .012313521193463961 .012313521193463961
.023952095808383422 .0290058479532167 .0290058479532167
.040935672514619936 .03659922709706742 .03659922709706742
.03651685393258419 .034649122807017596 .034649122807017596
.03252032520325221 .03009749894022873 .03009749894022873
.044619422572178324 .031069958847736556 .031069958847736556
.1306532663316584 .09159848333665943 .09159848333665943
.18888888888888888 .15319926873857392 .15319926873857392
.09158878504672896 .10779961953075445 .10779961953075445
.046232876712328785 .043932455638237 .043932455638237
.06219312602291338 .06497601096641548 .06497601096641548
.07704160246533132 .07825975028961252 .07825975028961252
.1258941344778255 .11113763877283067 .11113763877283067
.1410419313850062 .13461538461538436 .13461538461538436
.09131403118040082 .09251017896032576 .09251017896032576
.020408163265306145 .04004160166406656 .04004160166406656
.0129999999999999 .016166666666666663 .016166666666666663
.02369200394866744 .020911923896998363 .020911923896998363
-.004821600771456103 .008675395614105685 .008675395614105685
-.029069767441860517 -.013299354941466945 -.013299354941466945
.02594810379241519 .020500403551251045 .020500403551251045
.03988326848249035 .024992091110407788 .024992091110407788
.04957904583723094 .051311728395061706 .051311728395061706
.036541889483065804 .0492477064220187 .0492477064220187
NaN .02161443760492454 .02161443760492454
NaN .012256076686066208 .012256076686066208
NaN .012445887445887704 .012445887445887704
NaN .00634687332977002 .00634687332977002
NaN .01918608510920805 .01918608510920805
NaN .026120375195414214 .026120375195414214
NaN .004062718212403782 .004062718212403782
NaN -.008724789783144504 -.008724789783144504
NaN .018177179667070664 .018177179667070664
NaN .0375845652718616 .0375845652718616
NaN .019379376962086248 .019379376962086248
NaN -.01297009179745312 -.01297009179745312
NaN .03180127205088179 .03180127205088179
NaN .036403814840660687 .036403814840660687
NaN .0488160700258109 .0488160700258109
NaN .029424352664241482 .029424352664241482
NaN .03902920694314527 .03902920694314527
NaN .06392237283049074 .06392237283049074
NaN -.026279911616755247 -.026279911616755247
NaN .04166666666666674 .04166666666666674
NaN .059606025492468184 .059606025492468184
NaN .019421722584313628 .019421722584313628
"""
Code to run the Gibbs sampler
@author : Spencer Lyon <spencer.lyon@stern.nyu.edu>
@date : 2014-11-05 13:29:48
## Notes
* When defining the prior for σ_m I needed the mode of the InverseGamma.
This is equal to β / (α + 1).
"""
import Base: mean
using Distributions
using StateSpace
using MCMC
using HDF5
# Return the marginal means (E[ρ_m], E[σ_m]) of a NormalInverseGamma prior:
# the Normal location paired with the mean of the implied InverseGamma.
function mean(d::NormalInverseGamma)
    ig = InverseGamma(d.shape, d.scale)
    return (d.mu, mean(ig))
end
## ----------- ##
#- Handle Data -#
## ----------- ##
# pull in data from pandas
# using Pandas
# cs = Pandas.read_msgpack("cs_data.msg")
# cs_raw = copy(values(cs)) # copy converts to Julia array
# Read the pre-exported csv instead of going through Pandas.
# Column 1 appears to hold the "noisy" series and column 2 the "clean"
# series, with NaN where a series is unavailable -- see the range
# assertions further below, which verify exactly that layout.
cs_raw = readcsv("cs_raw.csv")
# split into training sample and estimation data
# The first 52 observations of the noisy series calibrate the priors;
# the rest of the sample (all columns) is used for estimation.
train = cs_raw[1:52, 1]
cs_data = cs_raw[53:end, :]
T = size(cs_data, 1)  # number of estimation periods
## ------------- ##
#- Define Priors -#
## ------------- ##
# dict for priors
# Keys are parameter names; values are the prior distributions consumed by
# the Gibbs blocks below (and by get_init for drawing starting values).
priors = Dict{Symbol, Distribution}()
priors[:μ0] = Normal(mean(train), 0.15) # good
priors[:π0] = Normal(mean(train), 0.025) # good
priors[:lnr0] = Normal(log(var(train)), 5.0) # good
priors[:lnq0] = Normal(log(var(train)/25), 5.0) # good
priors[:W] = InverseWishart(10.0, 0.05 * eye(2)) # TODO: check this
# priors[:ρ_m] = Normal(0.0, 0.45) # good
# priors[:σ_m] = InverseGamma(2*(0.2)/std(train) - 1, 0.2) # TODO: check this
# Joint prior over (ρ_m, σ_m). Per the header note the InverseGamma mode is
# β/(α+1); with α = 2*0.2/std(train) - 1 and β = 0.2 the σ_m prior mode
# works out to std(train)/2.
priors[:ρm_σm] = NormalInverseGamma(0.0, 0.45, 2*(0.2)/std(train) - 1, 0.2)
# Prior on the initial AR(1) state m0: spread scales with the stationary
# variance Eσ_m^2/(1-μ^2) of the m process, inflated by 10.
# NOTE(review): Distributions' Normal takes a *standard deviation* as its
# second argument -- confirm this variance-like expression is intended.
μ, Eσ_m = mean(priors[:ρm_σm])
priors[:m0] = Normal(0.0, 10*Eσ_m^2 / (1-μ^2))
## -------------- ##
#- Initial values -#
## -------------- ##
# Draw an initial parameter configuration from the priors.
#
# `priors` maps parameter symbols to prior distributions; `T` is the number
# of time periods. Returns a Dict{Symbol,Any} with vector-valued states
# (π, μ, m, r, q) of length T plus scalar/matrix draws for W, ρ_m, σ_m.
# (Draws are taken in the same order as before so RNG streams match.)
function get_init(priors, T)
    out = Dict{Symbol, Any}()
    out[:π] = rand(priors[:π0], T)
    out[:μ] = rand(priors[:μ0], T)
    out[:m] = rand(priors[:m0], T)
    out[:r] = exp(rand(priors[:lnr0], T))
    out[:q] = exp(rand(priors[:lnq0], T))
    out[:W] = rand(priors[:W])
    out[:ρ_m], out[:σ_m] = rand(priors[:ρm_σm])
    out
end
## --------- ##
#- Data Dict -#
## --------- ##
# construct data Dict for the model
data = Dict()
data["π0"] = mean(priors[:π0])
data["μ0"] = mean(priors[:μ0])
data["m0"] = mean(priors[:m0])
data["T"] = T # defined above when importing data
# Construct ranges of noisy, both, clean dates
# The estimation sample splits into three consecutive regimes:
#   noisy_range : only column 1 (noisy) observed; column 2 is NaN
#   both_range  : both series observed
#   clean_range : only column 2 (clean) observed; column 1 is NaN
# (`!isnan(...)` applied to an array is Julia 0.3 elementwise negation.)
noisy_range = 1:(findfirst(!isnan(cs_data[:, 2])) - 1)
both_range = (noisy_range.stop + 1):(findfirst(isnan(cs_data[:, 1])) -1)
clean_range = (both_range.stop + 1):T
## check to make sure ranges are accurate
# "clean" column is nan in all of noisy_range
@assert Base.all(isnan(cs_data[noisy_range, 2]))
# neither column is nan in both_range
@assert !Base.any(isnan(cs_data[both_range, 1:2][:]))
# "noisy" column is nan in clean_range
@assert Base.all(isnan(cs_data[clean_range, 1]))
# Make the TimeVaryingParam for the data
# tt is just the type of each of element in our comprehensions.
# I didn't want to type it three times.
# NOTE: `(A, B)` as a tuple type and `tt[expr for ...]` typed comprehensions
# are Julia 0.3 syntax (Tuple{A,B} and plain comprehensions later on).
tt = (Vector{Float64}, UnitRange{Int})
y = vcat(tt[([cs_data[t, 1]], t:t) for t in noisy_range], # noisy
tt[(cs_data[t, 1:2][:], t:t) for t in both_range], # both
tt[([cs_data[t, 2]], t:t) for t in clean_range]) # clean
y = TimeVaryingParam(y...)
data["y"] = y
## ----------------------------- ##
#- Measurement Equation Matrices -#
## ----------------------------- ##
# Now the measurement equation. This would be constant but for the
# changing dimensionality. Only need three C, D matrices
# The state is ordered [π, μ, m] (see pi_mu_m_block!): the noisy series
# loads π + m, the clean series loads π alone.
C_noisy = Float64[1 0 1]
C_both = Float64[1 0 1; 1 0 0]
C_clean = Float64[1 0 0]
Ct = TimeVaryingParam((C_noisy, noisy_range),
(C_both, both_range),
(C_clean, clean_range))
# measurement error is just machine epsilon
# (well, a small constant 1e-5 -- not literally eps(); the dimension of
# each diagonal block matches the number of rows of the matching C)
D_noisy = Diagonal(fill(1e-5, 1))
D_both = Diagonal(fill(1e-5, 2))
D_clean = Diagonal(fill(1e-5, 1))
Dt = TimeVaryingParam((D_noisy, noisy_range),
(D_both, both_range),
(D_clean, clean_range))
data["Ct"] = Ct
data["Dt"] = Dt
## ------------- ##
#- Define blocks -#
## ------------- ##
# One Metropolis accept/reject step for the time-t pair (r_t, q_t).
#
# `current` / `proposed` are 2-vectors [r_t, q_t] under the current state
# and under the proposal; `ζ` is a 2×T matrix of innovations whose t-th
# column is scored under N(0, Diagonal(v)) for each candidate v.
# On acceptance cur_p[:r][t], cur_p[:q][t] are set to `proposed`, otherwise
# (re)set to `current`. Consumes exactly one rand(). Returns nothing.
function rq_metropolis_step!(cur_p, current, proposed, t, ζ)
    # Covariance matrix of current and proposal
    H_proposed = Diagonal(proposed)
    H_current = Diagonal(current)
    ζ_t = ζ[:, t]
    # log acceptance probability. Working in log space avoids the
    # overflow/underflow of exp() on large quadratic forms, which could
    # make the ratio 0/0 = NaN and silently reject forever.
    lnαt = 0.5 * (logdet(H_current) - logdet(H_proposed))
    lnαt -= 0.5 * dot(ζ_t, H_proposed \ ζ_t)
    lnαt += 0.5 * dot(ζ_t, H_current \ ζ_t)
    # update v: accept iff rand() < αt, i.e. log(rand()) < log(αt)
    if lnαt > log(rand())
        cur_p[:r][t] = proposed[1]
        cur_p[:q][t] = proposed[2]
    else
        cur_p[:r][t] = current[1]
        cur_p[:q][t] = current[2]
    end
    nothing
end
# Gibbs block for the stochastic-variance chains r (measurement) and q
# (state). Single-move random-walk-in-logs Metropolis: each (log r_t,
# log q_t) pair gets a MvNormal proposal centered on the average of its
# neighbors and is accepted/rejected by rq_metropolis_step!.
# Mutates m.curr_params[:r] and m.curr_params[:q] in place.
function r_q_block!(m::MCMCGibbsModel)
T = m.data["T"]
cur_p = m.curr_params # we will use this a lot. Pull out to simplify notation
# fill in measurement (first row) and state innovations (second row)
ζ = Array(Float64, 2, T)
ζ[1, :] = cur_p[:π] - cur_p[:μ]
ζ[2, 2:end] = cur_p[:μ][2:end] - cur_p[:μ][1:end-1]
# the first state innovation is measured against the prior mean μ0
ζ[2, 1] = cur_p[:μ][1] - m.data["μ0"]
W = m.curr_params[:W] # get latest W
Ω = W .* 0.5 # proposal variance. This is constant across time.
# allocate space for time-t mean vector
μt = Array(Float64, 2)
# run single move sampler on interior points
for t=2:T-1
# Pull out data. Construct the mean
μt[1] = (log(cur_p[:r][t-1]) + log(cur_p[:r][t+1])) * 0.5
μt[2] = (log(cur_p[:q][t-1]) + log(cur_p[:q][t+1])) * 0.5
v_old = [cur_p[:r][t], cur_p[:q][t]] # current
# generate proposal
lnv_star = rand(MvNormal(μt, Ω)) # proposal
v_star = exp(lnv_star)
# do metropolis step (updates cur_p[:r][t] and cur_p[:q][t])
rq_metropolis_step!(cur_p, v_old, v_star, t, ζ)
end
# TODO: figure out how to deal with end points. Right now I just use
# a single observation. t is period of cur_p[:r], cur_p[:q] we are filling
# in and cond_t is the date we are conditioning on.
# TODO: both 2, T-1 have been updated this scan already. Should I worry?
for (t, cond_t) = [(1, 2), (T, T-1)]
μt[1], μt[2] = log(cur_p[:r][cond_t]), log(cur_p[:q][cond_t])
v_old = [cur_p[:r][t], cur_p[:q][t]] # current
lnv_star = rand(MvNormal(μt, Ω)) # proposal
v_star = exp(lnv_star)
rq_metropolis_step!(cur_p, v_old, v_star, t, ζ)
end
nothing
end
# Gibbs update for the random-walk innovation covariance W using the
# standard mean-zero Normal / InverseWishart conjugate update:
# posterior df = prior df + T, posterior scale = prior scale + η'η where
# η are the log first-differences of the (r, q) variance chains.
function W_block!(m::MCMCGibbsModel)
    rs = m.curr_params[:r]
    qs = m.curr_params[:q]
    # log first-differences of the two variance chains
    dlr = log(rs[2:end]) - log(rs[1:end-1])
    dlq = log(qs[2:end]) - log(qs[1:end-1])
    innov = [dlr dlq]
    # prior is mean zero, so the update only needs the crossproduct η'η
    prior = m.dists[:W]
    posterior = InverseWishart(m.data["T"] + prior.df,
                               full(prior.Ψ) + innov'innov)
    m.curr_params[:W] = rand(posterior)
    nothing
end
# Gibbs block for the latent states (π, μ, m): forward-filter /
# backward-sample from the linear-Gaussian state space implied by the
# current draws of (ρ_m, σ_m, r, q). Mutates m.curr_params in place.
function pi_mu_m_block!(m::MCMCGibbsModel)
# forward filter backward sampler
# Pull out current parameter values we need to condition on
ρ_m = m.curr_params[:ρ_m]
σ_m = m.curr_params[:σ_m]
r = m.curr_params[:r]
q = m.curr_params[:q]
# pull out other data
T = m.data["T"]
y = m.data["y"]
# Define state space matrices
# State ordering is [π, μ, m] (see the row extraction at the bottom).
A = Float64[0 1 0
0 1 0
0 0 ρ_m]
# Create space for time-dependent state space matrices
# NOTE: `Array((TupleType), T)` is Julia 0.3 syntax for a length-T vector.
Bt = Array((Array{Float64, 2}, UnitRange{Int}), T)
for t = 1:T
# Fill in period t state covariance matrix
this_B_t = Float64[sqrt(r[t]) sqrt(q[t]) 0
0 sqrt(q[t]) 0
0 0 σ_m]
Bt[t] = (this_B_t, t:t)
end
Bt = TimeVaryingParam(Bt...)
Ct, Dt = m.data["Ct"], m.data["Dt"]
# Finally construct the state space model
lgss = LinearGaussianSSabcd(A, Bt, Ct, Dt)
# initial mean and state covariance
# NOTE(review): this reads the file-level global `priors`, not m.priors --
# the commented alternative below suggests that was considered. It only
# works because the script-level dict is in scope; confirm it's intended.
μ0 = Float64[mean(priors[i]) for i in [:π0, :μ0, :m0]]
# μ0 = Float64[mean(m.priors[i]) for i in [:π0, :μ0, :m0]]
Σ0 = eye(3)
x0 = MvNormal(μ0, Σ0)
# Now run the fwfilter_bwsampler
new_state = fwfilter_bwsampler(lgss, y, x0)
# rows 1..3 of the sampled state path hold π, μ, m respectively
m.curr_params[:π] = squeeze(new_state[1, :], 1)
m.curr_params[:μ] = squeeze(new_state[2, :], 1)
m.curr_params[:m] = squeeze(new_state[3, :], 1)
nothing
end
# Gibbs block for (ρ_m, σ_m): conjugate NormalInverseGamma posterior update
# for the AR(1) regression m_t = ρ_m * m_{t-1} + σ_m * ε_t.
function rhom_sigmam_block!(m::MCMCGibbsModel)
# NOTE: In multi-dim version V is cov matrix. So I think in scalar version
# v should be variance
# pull out original params
pri = m.dists[:ρm_σm]
μ = pri.mu
v = pri.v0
α = pri.shape
β = pri.scale
# pull out data
m_t = m.curr_params[:m][2:end]
Lm = m.curr_params[:m][1:end-1] # lagged m (m_{t-1})
T = m.data["T"] - 1
# Update suffstats
# NOTE(review): `vp` and `μp` treat `v` as a *precision*
# (vp = (v + X'X)^-1), but the `μ'inv(v)*μ` term in βp treats it as a
# *variance*. Both cannot be right. With the prior mean μ == 0 used in
# this script the term vanishes, so the inconsistency is harmless here --
# confirm before reusing with a nonzero prior mean.
vp = inv(v + dot(Lm, Lm))
μp = vp * (v*μ + dot(Lm, m_t))
αp = α + T/2
βp = β + 0.5(dot(m_t, m_t) + μ'inv(v)*μ - μp'inv(vp)*μp)
# sample
new_dist = NormalInverseGamma(μp, vp, αp, βp)
m.curr_params[:ρ_m], m.curr_params[:σ_m] = rand(new_dist)
nothing
end
## --------------------------------------------------- ##
#- Construct Dict mapping param names to block numbers -#
## --------------------------------------------------- ##
# Parameters updated jointly share a block number.
# NOTE: `Union(...)` is Julia 0.3 call syntax (`Union{...}` later on).
block_sym2num = Dict{Union(Symbol, Vector{Symbol}), Int}()
block_sym2num[[:r, :q]] = 1
block_sym2num[:W] = 2
block_sym2num[[:π, :μ, :m]] = 3
block_sym2num[[:ρ_m, :σ_m]] = 4
## ------------------------------------------------- ##
#- Construct Dict mapping block numbers to functions -#
## ------------------------------------------------- ##
# The samplers sweep blocks 1..4 in this order every iteration
# (see the `for blk=1:m.n_block` loops in run_me / run_hdf).
block_funcs = Dict{Int, Function}()
block_funcs[1] = r_q_block!
block_funcs[2] = W_block!
block_funcs[3] = pi_mu_m_block!
block_funcs[4] = rhom_sigmam_block!
## ---------------------------- ##
#- Construct the MCMCGibbsModel -#
## ---------------------------- ##
# Left commented so the file can be include()'d without building the model;
# uncomment to construct it with its 4 blocks.
# m = MCMCGibbsModel(get_init(priors, T), priors, block_sym2num, block_funcs, data, 4)
## --------------- ##
#- Run the sampler -#
## --------------- ##
#=
_allocate is a helper function that the sampler code calls when
allocating memory for samples of various types.
This method has been implemented for scalars, vectors, and matrices
The trailing dimension of the resultant array will be the length of the
chain T. The leading dimensions will match size(x).
This should be used in conjunction with the _save! method
=#
_allocate{S <: Number}(::S, T::Int) = Array(S, T)
_allocate{S <: Number}(x::Vector{S}, T::Int) = Array(S, length(x), T)
_allocate{S <: Number}(x::Matrix{S}, T::Int) = Array(S, size(x, 1),
size(x, 2), T)
# _save! writes one draw `x` into trailing-dimension slot `t` of the chain
# array d[nm] previously created by _allocate; one method per sample rank.
_save!(d::Dict, nm::Symbol, x::Real, t::Int) = d[nm][t] = x
_save!(d::Dict, nm::Symbol, x::Vector, t::Int) = d[nm][:, t] = x
_save!(d::Dict, nm::Symbol, x::Matrix, t::Int) = d[nm][:, :, t] = x
# No-op text-saving hooks: placeholders that mirror the _save! API so a
# plain-text backend could be slotted in later. Each returns nothing.
_save_text!(io::IOStream, nm::Symbol, x::Real, t::Int) = nothing
_save_text!(d::Dict, nm::Symbol, x::Vector, t::Int) = nothing
_save_text!(d::Dict, nm::Symbol, x::Matrix, t::Int) = nothing
# Run the sampler for last(r) sweeps, keeping in memory only the draws
# whose sweep index lies in `r` -- e.g. r = 101:2:200 burns 100 sweeps
# then keeps every second draw. Returns a Dict mapping parameter names to
# sample arrays whose trailing dimension has length(r) entries.
function run_me(m::MCMCGibbsModel, r::Range=1:200)
param_names = keys(m.curr_params)
samples = Dict{Symbol, Array}()
T = length(r)
for nm in param_names
samples[nm] = _allocate(m.curr_params[nm], T)
end
# Print 10 times during simulation
# (`int(...)` is Julia 0.3 rounding conversion)
print_checks = int([(1:10) * (last(r)/10)])
t = 1
for i=1:last(r)
# update all blocks
for blk=1:m.n_block
m.block_funcs[blk](m)
end
# if we aren't thinning this time
if i in r
# extract new sample
for nm in param_names
_save!(samples, nm, m.curr_params[nm], t)
end
t += 1
end
if i in print_checks
println("Finished with $i of $(last(r))")
end
end
samples
end
# convenience overload: an integer t means "keep every one of t sweeps"
run_me(m::MCMCGibbsModel, t::Int) = run_me(m, 1:t)
## ------------------------------------------------------- ##
#- Running method to save intermediate results to hdf file -#
## ------------------------------------------------------- ##
"""
Create an HDF5 dataset plus an in-memory chunk buffer for one parameter.

* the first argument is a prototype draw (scalar, vector, or matrix) --
  its shape fixes the leading dataset dimensions
* `f` is the open HDF5 file handle
* `nm` is the dataset name at root level or with `group/name` syntax
* `T` is the total (trailing) length of the dataset
* `cs` is the chunk size

Returns `(dataset, buffer)` where `buffer` accumulates `cs` draws between
flushes (see `_save_hdf!`).
"""
function _allocate_hdf{S <: Number}(::S, f::HDF5File, nm::String, T::Int,
cs::Int)
d = d_create(f, nm, datatype(S), dataspace((T, )), "chunk", (cs))
a = Array(S, cs)
return d, a
end
# vector draws: dataset is (length(x), T), buffer is (length(x), cs)
function _allocate_hdf{S <: Number}(x::Vector{S}, f::HDF5File, nm::String,
T::Int, cs::Int)
s1 = length(x)
d = d_create(f, nm, datatype(S), dataspace(s1, T), "chunk", (s1, cs))
a = Array(S, s1, cs)
return d, a
end
# matrix draws: dataset is (size(x)..., T), buffer is (size(x)..., cs)
function _allocate_hdf{S <: Number}(x::Matrix{S}, f::HDF5File, nm::String,
T::Int, cs::Int)
s1, s2 = size(x)
d = d_create(f, nm, datatype(S), dataspace(s1, s2, T), "chunk",
(s1, s2, cs))
a = Array(S, s1, s2, cs)
return d, a
end
# Flush one chunk of buffered draws into the on-disk dataset. `t` is the
# range of trailing-dimension slots the chunk occupies; one method per
# sample rank (scalar chains, vector chains, matrix chains).
function _save_hdf!(g::HDF5Dataset, x::Vector, t::UnitRange)
g[t] = x
nothing
end
function _save_hdf!(g::HDF5Dataset, x::Matrix, t::UnitRange)
g[:, t] = x
nothing
end
function _save_hdf!{T<:Real}(g::HDF5Dataset, x::Array{T, 3}, t::UnitRange)
g[:, :, t] = x
nothing
end
# Run the Gibbs sampler, streaming retained draws to an HDF5 file in
# chunks instead of holding the whole chain in memory.
#
# * `r` : range of sweep indices to keep (last(r) total sweeps; indices
#   not in `r` are thinned away)
# * `filename` : target h5 file (one per worker by default)
# * `cs` : chunk size. The default (cs == 0) tries to have at least 20
#   chunks while bounding the chunk size between min(length(r), 500)
#   and 1000.
function run_hdf(m::MCMCGibbsModel, r::Range=1:200;
                 filename::String="cs2014_$(myid()).h5",
                 cs::Int=0)
    if cs == 0
        # set real default for cs (see bounds described above)
        cs = min(max(min(int(length(r)/20), 1000), 500), length(r))
    end
    param_names = keys(m.curr_params)
    samples = Dict{Symbol, Array}()
    dsets = Dict{Symbol, HDF5Dataset}()
    f = h5open(filename, "w")
    T = length(r)
    for nm in param_names
        dsets[nm], samples[nm] = _allocate_hdf(m.curr_params[nm], f,
                                               string(nm), T, cs)
    end
    # Print 10 times during simulation. Guard against print_skip == 0
    # (which would make `i % print_skip` throw) when last(r) < 10.
    print_skip = max(1, int(last(r)/10))
    # chunker keeps track of which slot of the current chunk we are on
    chunker = 1
    # fill_range is the slice of the hdf dataset the current chunk maps to
    fill_range = 1:cs
    for i=1:last(r)
        # update all blocks
        for blk=1:m.n_block
            m.block_funcs[blk](m)
        end
        # if we aren't thinning this iteration
        if i in r
            # if our chunk still isn't totally full, grab new samples
            if chunker <= cs
                for nm in param_names
                    _save!(samples, nm, m.curr_params[nm], chunker)
                end
                chunker += 1
            end
            # check to see if chunk is full.
            if chunker == cs + 1
                # if it is, save the data for this chunk
                for nm in param_names
                    _save_hdf!(dsets[nm], samples[nm], fill_range)
                end
                # shift the fill_range to the next chunk's slots
                fill_range += cs
                # reset the chunker
                chunker = 1
            end
        end
        if i % print_skip == 0
            println("Finished with $i of $(last(r))")
        end
    end
    # Flush any partially-filled final chunk. Without this, the last
    # length(r) % cs retained draws were buffered but never written.
    if chunker > 1
        n_left = chunker - 1
        last_range = fill_range.start:(fill_range.start + n_left - 1)
        for nm in param_names
            x = samples[nm]
            if ndims(x) == 1
                _save_hdf!(dsets[nm], x[1:n_left], last_range)
            elseif ndims(x) == 2
                _save_hdf!(dsets[nm], x[:, 1:n_left], last_range)
            else
                _save_hdf!(dsets[nm], x[:, :, 1:n_left], last_range)
            end
        end
    end
    close(f)
    nothing
end
# convenience overload: an integer t means "keep every one of t sweeps"
run_hdf(m::MCMCGibbsModel, t::Int) = run_hdf(m, 1:t)
## -------------------------------------------------------------- ##
#- One more useful function to concat my parallely obtained data -#
## -------------------------------------------------------------- ##
"""
Combine an arbitrary number of *similar* h5 files into one. "Similar"
means the names, sizes, and types of all datasets are the same in every
file.

Concatenation is along the last dimension: with N files and a dataset of
dimensions (i, j), the resulting dataset has dimensions (i, N*j).

* `main_fname` : path of the combined file to create (opened with "w",
  so an existing file is overwritten)
* `f_names` : paths of the per-worker files to stitch together
"""
function combine_data{XX <: String}(main_fname::String, f_names::Vector{XX})
    f = h5open(main_fname, "w")
    N_files = length(f_names)
    # use the first file as the reference for names/sizes/types
    f2 = h5open(f_names[1], "r")
    # extract useful info about objects from first file
    nms = names(f2)
    info_dict = Dict()
    for nm in nms
        this_nm = Dict()
        sz = size(f2[nm])
        this_nm["size"] = sz
        this_nm["ndims"] = length(sz)
        # read a single element to recover the stored element type
        this_nm["eltype"] = eltype(f2[nm][ind2sub(sz, 1)...])
        info_dict[nm] = this_nm
    end
    close(f2)
    # loop over all names and stitch together the different files
    for nm in nms
        sz = info_dict[nm]["size"]
        nobs = sz[end] # How many observations in this dataset?
        inds = 1:nobs # will allow us to fill in later
        # construct data type and space, then dataset
        dtype = datatype(info_dict[nm]["eltype"])
        dspace = dataspace(tuple(sz[1:end-1]..., N_files*sz[end]))
        d = d_create(f, nm, dtype, dspace, "chunk", sz)
        for fnm in f_names
            this_f = h5open(fnm, "r")
            _save_hdf!(d,
                       this_f[nm][[Colon() for i=1:info_dict[nm]["ndims"]]...],
                       inds)
            # close the source file: previously these handles were leaked
            # (one open handle per file per dataset, never released)
            close(this_f)
            # NOTE: range + integer shifting is Julia 0.3 behavior
            inds += nobs
        end
    end
    close(f)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment