Skip to content

Instantly share code, notes, and snippets.

@sunhwan
Last active January 2, 2016 21:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sunhwan/8363019 to your computer and use it in GitHub Desktop.
Save sunhwan/8363019 to your computer and use it in GitHub Desktop.
WHAM
import sys
import numpy as np
# Any extra command-line argument enables debug mode, in which only the first
# window's time series is read and histogrammed.
debug = len(sys.argv) > 1

# All run parameters arrive on stdin, one item per line.
# NOTE: the name `input` shadows the builtin; it is kept because the
# window-reading loop later in this script reads from it under this name.
input = sys.stdin
pmf_filename = input.readline().strip()  # output: potential of mean force
rho_filename = input.readline().strip()  # output: unbiased (average) density
bia_filename = input.readline().strip()  # output: biased (raw) distribution
fff_filename = input.readline().strip()  # output: per-window free energies F(i)
temperature = float(input.readline().strip())
xmin, xmax, delta, is_x_periodic = map(float, input.readline().split())
nwin, niter, fifreq = map(int, input.readline().split())
# Convergence tolerance on max_i |F_new(i) - F_old(i)|, parsed as a plain
# float. `map(...)` would yield a list (Python 2) or a lazy iterator
# (Python 3), and the latter cannot be compared with a number.
tol = float(input.readline().split()[0])
is_x_periodic = bool(is_x_periodic)  # parsed for the input format; not otherwise used

# Number of histogram bins: (xmax - xmin) / delta rounded to the nearest integer.
nbin = int((xmax - xmin + 0.5 * delta) / delta)

kb = 0.0019872  # Boltzmann constant (this value corresponds to kcal/(mol*K))
kbt = kb * temperature
beta = 1.0 / kbt

k1 = np.zeros(nwin)   # bias force constant of each window
cx1 = np.zeros(nwin)  # bias center of each window
# Time-series filename of each window. A plain list is used because
# np.empty(nwin, dtype='S') has a 1-byte itemsize and silently truncates
# every stored name to its first character.
tseries = [''] * nwin
# Per-window histogram counts and per-window total sample counts. The builtin
# `int` is used because `np.int` was removed in NumPy 1.24.
hist = np.zeros((nwin, nbin), dtype=int)
nb_data = np.zeros(nwin, dtype=int)


def x1(j):
    """Return the center coordinate of histogram bin ``j``.

    Also works elementwise when ``j`` is an integer NumPy array.
    """
    return xmin + (j + 1) * delta - 0.5 * delta
def _read_xdata(fname):
    """Return the second column of time-series file ``fname`` as a float array.

    Each non-blank line must hold at least two whitespace-separated numbers;
    the first one (presumably the time) is parsed but discarded.
    """
    xdata = []
    with open(fname) as series:
        for line in series:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines
            _time, x = map(float, fields[:2])
            xdata.append(x)
    return np.array(xdata)


def mkhist(fname, xmin, xmax, delta, ihist):
    """Histogram one window's time series into row ``ihist`` of global ``hist``.

    The data in ``fname`` are binned on the module-level ``nbin`` bins of width
    ``delta`` starting at ``xmin``. Samples below the first edge or above the
    last edge are folded into the first or last bin. The total sample count is
    stored in ``nb_data[ihist]``, and summary statistics are printed.

    ``xmax`` is kept for interface compatibility. The upper edge actually used
    is ``xmin + nbin * delta``.

    Raises:
        ValueError: if ``fname`` contains no data lines.
    """
    x = _read_xdata(fname)
    if x.size == 0:
        raise ValueError("no data found in time series %s" % fname)
    bin_edges = xmin + delta * np.arange(nbin + 1)
    counts, edges = np.histogram(x, bins=bin_edges)
    # np.histogram drops samples outside [edges[0], edges[-1]]; clamp them
    # into the end bins so every sample is counted.
    counts[0] += np.count_nonzero(x < bin_edges[0])
    counts[-1] += np.count_nonzero(x > bin_edges[-1])
    hist[ihist] = counts
    nb_data[ihist] = int(counts.sum())

    print('statistics for timeseries #  %d' % ihist)
    print('minx: %8.3f maxx: %8.3f' % (np.min(x), np.max(x)))
    print('average x %8.3f rms x %8.3f' % (np.average(x), np.std(x)))
    print('statistics for histogram #  %d' % ihist)
    print('%d points in the histogram' % nb_data[ihist])
    centers = 0.5 * (edges[:-1] + edges[1:])
    print('average x %8.3f' % (np.dot(counts, centers) / counts.sum()))
    print('')


# For each window, read its time-series filename and its bias parameters
# ("center force_constant") from stdin, then histogram its data.
for i in range(nwin):
    fname = input.readline().strip()
    tseries[i] = fname
    cx1[i], k1[i] = map(float, input.readline().split())
    mkhist(fname, xmin, xmax, delta, i)
    if debug:
        break  # debug mode: process only the first window
# Write the biased distribution: raw counts in each bin summed over all
# windows, one "%8d" value per line. The `with` block guarantees that the
# file is flushed and closed.
with open(bia_filename, 'w') as bia_file:
    for count in hist.sum(axis=0):
        bia_file.write("%8d\n" % count)
# Iterate the WHAM equations to unbias and recombine the window histograms:
#   rho(x_j) = sum_i n_i(x_j) / sum_i N_i exp(beta * (F_i - V_i(x_j)))
#   F_i      = -kT ln sum_j rho(x_j) exp(-beta * V_i(x_j))
# Here n_i is window i's histogram (hist), N_i its sample count (nb_data) and
# V_i its bias k1[i] * (x - cx1[i])**2. Iteration stops once
# max_i |F_i(new) - F_i(old)| <= tol, or after niter - 1 cycles.
TOP = hist.sum(axis=0)  # total counts per bin over all windows
xcenters = x1(np.arange(nbin))
V1 = k1[:, np.newaxis] * (xcenters[np.newaxis, :] - cx1[:, np.newaxis]) ** 2
bias_weight = np.exp(-beta * V1)  # exp(-beta * V_i(x_j)); fixed across cycles
rho = np.zeros(nbin)
F = np.zeros(nwin)
F2 = np.zeros(nwin)
icycle = 1
while icycle < niter:
    BOT = np.sum(nb_data[:, np.newaxis]
                 * np.exp(beta * (F[:, np.newaxis] - V1)), axis=0)
    rho = TOP / BOT
    F2 = -kbt * np.log(np.dot(bias_weight, rho))
    diff = np.max(np.abs(F2 - F))
    # A NaN diff compares False here, so it also ends the iteration.
    converged = not diff > tol
    print('round =  %d diff =  %s' % (icycle, diff))
    icycle += 1
    last = icycle == niter or converged
    # Every fifreq cycles, and on the final cycle, overwrite fff_filename
    # with the current F(i) estimates.
    if last or (fifreq != 0 and icycle % fifreq == 0):
        with open(fff_filename, 'w') as fff_file:
            fff_file.write("%8i %s\n" % (icycle, " ".join(["%8.3f" % fi for fi in F2])))
    if last:
        break
    F = F2
# Bin with the highest density, located before the Jacobian correction below.
jpeak = np.argmax(rho)
# Jacobian correction: divide the density by x**2. This is presumably because
# x is a radial distance whose shell volume grows as x**2 (confirm for other
# coordinates). A bin centered at x = 0 would divide by zero.
rho = rho / x1(np.arange(nbin)) ** 2
# Reference density for the PMF zero: the mean of the last five bins (or of
# all bins when there are fewer than five). With this choice the PMF averages
# to about zero over those bins.
rhomax = np.mean(rho[-5:])
print('maximum density at: x =  %s' % x1(jpeak))
# Make the PMF from rho. An empty bin gives log(0) = -inf, i.e. an infinite
# PMF. The divide-by-zero warning is silenced only for this expression rather
# than globally.
with np.errstate(divide='ignore'):
    pmf = -kbt * np.log(rho / rhomax)
with open(pmf_filename, 'w') as pmf_file:
    pmf_file.write("\n".join(["%8.3f %12.3f" % (x1(j), pmf[j]) for j in range(nbin)]))
with open(rho_filename, 'w') as rho_file:
    rho_file.write("\n".join(["%8.3f %12.3f" % (x1(j), rho[j]) for j in range(nbin)]))
# Settings for generating the WHAM input lists.
temps = [283.15, 286.83, 290.55, 294.33, 298.15]  # paths are prefixed by each temperature's index
xmin = 3.5         # center of the first window
xmax = 10.5        # upper end of the window range
nwin = 15          # windows per temperature
xbin = (0.5, 0.1)  # (spacing between window centers, histogram bin width)
umin = -7          # umin/umax/ubin are used only by the combined wham.list
umax = 5
ubin = 20
fc = 1.25          # force constant written for every window
xbuf = 0.25        # margin added around [xmin, xmax] for the histogram range

# One 1-D WHAM input list per temperature, in the order the WHAM script reads
# stdin:
#   - four output paths, then the temperature;
#   - "xmin xmax delta periodic";
#   - "nwin niter fifreq";
#   - the tolerance;
#   - then two lines per window: its trajectory path and
#     "center force_constant".
# NOTE(review): the upper bound uses 2*xbuf, while the lower bound (and the
# combined wham.list) use xbuf -- confirm the asymmetry is intended.
list_header_fmt = "\n".join([
    "%d/run.pmf",
    "%d/run.rho",
    "%d/run.bia",
    "%d/run.fff",
    "%8.3f",
    "%8.3f %8.3f %8.3f %8.3f",
    "%8d %8d %8d",
    "0.000001",
]) + "\n"
list_entry_fmt = "%d/%d.traj\n%8.3f %8.3f\n"
for i, temp in enumerate(temps):
    whamfile = 'wham.%d.list' % i
    # `with` closes (and flushes) each list; previously the last one stayed open.
    with open(whamfile, 'w') as fp:
        fp.write(list_header_fmt % (i, i, i, i, temp, xmin - xbuf, xmax + 2 * xbuf,
                                    xbin[1], 0.0, nwin, 10000, 100))
        for j in range(nwin):
            fp.write(list_entry_fmt % (i, j, xmin + j * xbin[0], fc))
# Combined WHAM input list covering every temperature, written to wham.list.
# The reference temperature is temps[2].
# NOTE(review): this list targets a different reader than the 1-D WHAM script
# -- confirm against that reader:
#   - the header has an extra line (umin-100, umax+100, ubin, len(temps));
#   - ubin and len(temps) are integers but are formatted with %8.3f;
#   - the window-count field holds nwin, although len(temps)*nwin entries
#     follow;
#   - each entry carries three numbers.
combined_header_fmt = "\n".join([
    "run.pmf",
    "run.rho",
    "run.bia",
    "run.fff",
    "%8.3f",
    "%8.3f %8.3f %8.3f %8.3f",
    "%8.3f %8.3f %8.3f %8.3f",
    "%8d %8d %8d",
    "0.000001",
]) + "\n"
combined_entry_fmt = "%d/%d.temp\n%8.3f %8.3f %8.3f\n"
whamfile = 'wham.list'
# `with` guarantees the list is flushed and closed.
with open(whamfile, 'w') as fp:
    fp.write(combined_header_fmt % (temps[2], xmin - xbuf, xmax + xbuf, xbin[1], 0.0,
                                    umin - 100, umax + 100, ubin, len(temps),
                                    nwin, 10000, 100))
    for i, temp in enumerate(temps):
        for j in range(nwin):
            # "<temperature index>/<window>.temp", then
            # "center force_constant temperature".
            fp.write(combined_entry_fmt % (i, j, xmin + j * xbin[0], fc, temp))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment