Skip to content

Instantly share code, notes, and snippets.

@sunhwan
Last active May 15, 2020 20:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sunhwan/70e480eab96427cfe0b3c95820df3e0d to your computer and use it in GitHub Desktop.
Save sunhwan/70e480eab96427cfe0b3c95820df3e0d to your computer and use it in GitHub Desktop.
Compute the number density of a probe along the Z axis from GROMACS trajectories.
import argparse
from pathlib import Path
import re
import sys
from multiprocessing import Pool
from functools import partial
from tqdm import tqdm
import mdtraj as md
import numpy as np
# Rounded atomic masses (daltons) keyed by the element symbol stored in the
# ``element`` column of the mdtraj topology dataframe. They are only used as
# weights in np.average, so rounding to these precisions is sufficient.
# 'VS' (virtual sites) carry no mass and therefore do not move a center of mass.
MASS = {
    'H': 1.0,
    'C': 12.0,
    'N': 14.0,
    'O': 16.0,
    'F': 19.0,
    'P': 31.0,
    'S': 32.0,
    'Cl': 35.5,
    'VS': 0.0,
}
def _center_positions(t, df, center_option):
    """Return the per-frame origin of shape (n_frames, 1, 3), or None if the center group is empty.

    'box' uses half of each frame's unit-cell lengths. 'lipid' uses the
    mass-weighted center of all DOPC/POPC/DPPC/CHL1 atoms. Any other value is
    treated as a residue name, and the mass-weighted center of all its atoms is used.
    """
    if center_option == 'box':
        return (t.unitcell_lengths / 2)[:, np.newaxis, :]
    if center_option == 'lipid':
        index = df['resName'].isin(['DOPC', 'POPC', 'DPPC', 'CHL1'])
    else:
        index = df['resName'] == center_option
    df_center = df[index]
    if len(df_center) == 0:
        print('no residues with resname =', center_option, 'selected for center group; exit', file=sys.stderr)
        return None
    mass = [MASS[element] for element in df_center['element']]
    coords = t.xyz[:, df_center.index.to_list()]
    # One center for the whole group per frame (not one per residue).
    return np.average(coords, axis=1, weights=mass)[:, np.newaxis, :]


def _group_residue_atoms(df_residues):
    """Split selected atom rows into residues; return (list of atom-index lists, list of mass lists)."""
    # Residues are delimited by a change of resSeq between consecutive rows rather
    # than by unique resSeq values: because of the numbering limitation in PDB,
    # two distinct residues may share the same resSeq.
    groups = []
    previous = None
    for idx, res_seq in zip(df_residues.index, df_residues['resSeq']):
        if not groups or res_seq != previous:
            groups.append([])
        groups[-1].append(idx)
        previous = res_seq
    masses = [[MASS[element] for element in df_residues.loc[group, 'element']] for group in groups]
    return groups, masses


def histogram(resname, atomname, symmetry, center_option, ranges, n_bins, use_pbc, use_com, trajfile):
    """Compute the per-frame number density of selected residues along Z for one trajectory.

    Parameters
    ----------
    resname : str
        Residue name (``resName`` column of the topology dataframe) to count.
    atomname : str or None
        If given, only atoms with this name are used from each residue.
    symmetry : bool
        Not used here; the caller symmetrizes the summed histogram.
    center_option : str
        'box', 'lipid', or a residue name; selects the origin Z is measured from.
    ranges : tuple of two floats
        Lower and upper Z limits passed to ``np.histogram``.
    n_bins : int
        Number of equal-width histogram bins.
    use_pbc : bool
        Also count the periodic images shifted by +/- one box length along Z.
    use_com : bool
        Histogram each residue's center of mass instead of every selected atom.
    trajfile : tuple
        ``(xtc, pdb)`` pair of trajectory and topology paths.

    Returns
    -------
    numpy.ndarray or None
        Shape (n_frames, n_bins). Each entry is the count in a bin divided by that
        frame's slab volume (box X * box Y * bin width), in the trajectory's length
        units (mdtraj uses nm). Returns None if the center group or the residue
        selection is empty.

    Raises
    ------
    ValueError
        If the selected residues do not all contain the same number of atoms.
    """
    xtc, pdb = trajfile
    t = md.load(str(xtc), top=str(pdb))
    n_frames = t.n_frames
    df, _ = t.top.to_dataframe()

    center = _center_positions(t, df, center_option)
    if center is None:
        return None

    index = df['resName'] == resname
    if atomname:
        index = index & (df['name'] == atomname)
    df_residues = df[index]
    if len(df_residues) == 0:
        print('no residues with resname =', resname, 'and atomname =', atomname, 'selected; exit', file=sys.stderr)
        return None

    atom_indices, mass = _group_residue_atoms(df_residues)
    # Per-residue coordinates are stacked into one array, which needs equal atom counts.
    if len({len(indices) for indices in atom_indices}) != 1:
        raise ValueError('selected residues of %s do not all have the same number of atoms; '
                         'select a single atom per residue with --atomname' % resname)

    # coords has shape (residue, frame, atom, xyz), relative to the chosen center.
    coords = np.array([t.xyz[:, indices] - center for indices in atom_indices])
    if use_pbc:
        # Append periodic images shifted one box length up and down along Z.
        box_z = t.unitcell_lengths[np.newaxis, :, np.newaxis, 2]
        z_up = coords.copy()
        z_up[:, :, :, 2] += box_z
        z_down = coords.copy()
        z_down[:, :, :, 2] -= box_z
        coords = np.concatenate([coords, z_up, z_down], axis=0)

    if use_com:
        # All residues are weighted with the first residue's masses; residues of one
        # resname are assumed to share the same atomic composition.
        com = np.average(coords, axis=2, weights=mass[0])
        z = com[:, :, 2]            # (residue, frame)
    else:
        z = coords[:, :, :, 2]      # (residue, frame, atom)

    # Compute number density across Z, one histogram per frame.
    hist = np.array([np.histogram(z[:, i], bins=n_bins, range=ranges)[0] for i in range(n_frames)])
    # np.histogram splits ``ranges`` into n_bins equal bins, so each slab is (hi - lo) / n_bins thick.
    height = (ranges[1] - ranges[0]) / n_bins
    slab_volume = t.unitcell_lengths[:, 0] * t.unitcell_lengths[:, 1] * height
    return hist / slab_volume[:, np.newaxis]
def compute_density(trajfiles, resname, atomname, symmetry, center_option, n_bins, use_pbc, use_com, use_thread,
                    ranges=(-4, 4)):
    """Print the frame-averaged number density profile along Z over all trajectories.

    Each output line is ``<bin center> <density>``, where density is the summed
    per-frame density from ``histogram`` divided by the total number of frames.

    Parameters
    ----------
    trajfiles : list of (xtc, pdb) tuples
        Trajectory/topology pairs to process.
    resname, atomname, center_option, use_pbc, use_com
        Forwarded to ``histogram``.
    symmetry : bool
        Average the profile with its mirror image. This assumes ``ranges`` is
        symmetric about 0.
    n_bins : int
        Requested bin count; an even value is increased by one.
    use_thread : bool
        Process trajectories in parallel with a ``multiprocessing.Pool`` (processes, not threads).
    ranges : tuple of two floats, optional
        Z range to histogram; defaults to (-4, 4).

    Exits with status 1 if there is nothing to process or a trajectory selection is empty.
    """
    if not trajfiles:
        print('no trajectory files to process; exit', file=sys.stderr)
        sys.exit(1)
    # With an odd bin count and a range symmetric about 0, Z = 0 falls on a bin center.
    if n_bins % 2 == 0:
        n_bins += 1
    bins = np.histogram_bin_edges([], n_bins, ranges)
    func = partial(histogram, resname, atomname, symmetry, center_option, ranges, n_bins, use_pbc, use_com)

    hist = []
    if use_thread:
        with Pool() as p:
            for count in tqdm(p.imap_unordered(func, trajfiles), total=len(trajfiles)):
                if count is None:
                    sys.exit(1)
                hist.append(count)
    else:
        for trajfile in tqdm(trajfiles):
            count = func(trajfile)
            if count is None:
                sys.exit(1)
            hist.append(count)

    # Stack all frames of all trajectories, then sum over frames.
    hist = np.concatenate(hist)
    total_frames = hist.shape[0]
    hist = np.sum(hist, axis=0)
    if symmetry:
        hist = (hist + hist[::-1]) / 2

    z_bin_center = (bins[1:] + bins[:-1]) / 2
    for z, density in zip(z_bin_center, hist / total_frames):
        print(z, density)
def _find_topology(xtcfile):
    """Return the PDB topology path expected next to *xtcfile*, or None (with a warning) if unavailable."""
    # Topology name is derived from the "<name>.prod.<N>" prefix of the XTC file name.
    match = re.match(r'(.*?\.prod\.[0-9]+)', xtcfile.name)
    if match is None:
        print('cannot derive topology name from', xtcfile, '; skipped', file=sys.stderr)
        return None
    prefix = match.group(1)
    if 'nohyd' in xtcfile.name:
        topname = "%s.rec.nohyd.pdb" % prefix
    else:
        topname = "%s.rec.pdb" % prefix
    topfile = xtcfile.parent / topname
    if not topfile.exists():
        print('PDB file', topfile, 'not found; skipped', file=sys.stderr)
        return None
    return topfile


def main():
    """Parse command-line options, collect (xtc, pdb) pairs, and print the density profile along Z.

    Trajectories come from explicit ``--trajfile`` arguments and/or from
    ``--trajdir`` directories searched with ``--pattern`` (recursively with
    ``--recursive``). Files without a matching topology are skipped with a warning.
    Exits with status 1 on missing arguments or when no usable pair is found.
    """
    parser = argparse.ArgumentParser(description='process trajectory and compute densities along Z-axis')
    parser.add_argument('--trajfile', action="append", help='trajectory files')
    parser.add_argument('--trajdir', action="append", help='directories containing trajectories')
    parser.add_argument('--recursive', action='store_true', help='search directories recursively')
    parser.add_argument('--pattern', default="*.prod.*.xtc", help='search trajectory pattern')
    parser.add_argument('--resname', help='selected residue')
    parser.add_argument('--atomname', default=None, help='selected atom name')
    parser.add_argument('--symmetry', action='store_true', help='use symmetry')
    parser.add_argument('--nbins', default=51, type=int, help='number of bins')
    parser.add_argument('--pbc', action='store_true', help='consider PBC')
    parser.add_argument('--com', action='store_true', help='use center of mass')
    parser.add_argument('--thread', action='store_true', help='use thread')
    parser.add_argument('--center', default='box', help='center group (choices: box or residue name)')
    args = parser.parse_args()

    if not args.resname:
        print("residue name need to be supplied; exit\n", file=sys.stderr)
        parser.print_help(sys.stderr)
        sys.exit(1)
    if not args.trajfile and not args.trajdir:
        parser.print_help(sys.stderr)
        sys.exit(1)

    xtcfiles = []
    for trajfile in args.trajfile or []:
        xtcfile = Path(trajfile)
        if not xtcfile.exists():
            print('XTC file', xtcfile, 'not found; skipped', file=sys.stderr)
            continue
        xtcfiles.append(xtcfile)
    for trajdir in args.trajdir or []:
        directory = Path(trajdir)
        # rglob(p) is equivalent to glob('**/' + p).
        found = directory.rglob(args.pattern) if args.recursive else directory.glob(args.pattern)
        xtcfiles.extend(found)

    trajfiles = []
    for xtcfile in xtcfiles:
        topfile = _find_topology(xtcfile)
        if topfile is not None:
            trajfiles.append((xtcfile, topfile))
    if not trajfiles:
        print('no trajectory/topology pairs found; exit', file=sys.stderr)
        sys.exit(1)

    compute_density(trajfiles, args.resname, args.atomname, args.symmetry, args.center, args.nbins, args.pbc,
                    args.com, args.thread)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment