Skip to content

Instantly share code, notes, and snippets.

@fabian-paul
Last active March 9, 2020 23:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fabian-paul/74a3797b4ce1b728ea507bb01bbd56b3 to your computer and use it in GitHub Desktop.
Save fabian-paul/74a3797b4ce1b728ea507bb01bbd56b3 to your computer and use it in GitHub Desktop.
Delete all water molecules from an MD trajectory except those closest to the ligand.
#!/usr/bin/env python
import numpy as np
import mdtraj
def tl(filename, chunk=100):
    """Return the number of ``chunk``-frame pieces in a trajectory file.

    The result is only used as an advisory count (progress-bar total / log
    message). With ``chunk=1`` it is the frame count. The default of 100 is
    presumably chosen to match ``mdtraj.iterload``'s default chunk size --
    confirm if iterload is ever called with a different chunk.

    Parameters
    ----------
    filename : str
        Path of a trajectory file that ``mdtraj.open`` can read.
    chunk : int, optional
        Number of frames per piece; must be at least 1.

    Returns
    -------
    int or None
        ``ceil(n_frames / chunk)`` (0 for an empty file), or ``None`` when the
        file handler cannot report its length.

    Raises
    ------
    ValueError
        If ``chunk`` is smaller than 1.
    """
    if chunk < 1:
        raise ValueError('chunk must be a positive integer, got %r' % (chunk,))
    with mdtraj.open(filename, mode='r') as fh:
        try:
            n_frames = len(fh)
        except (TypeError, NotImplementedError, OSError):
            # Not every format handler can count frames cheaply (or at all).
            # The count is purely informational, so degrade to "unknown"
            # instead of aborting the run.
            return None
    # Ceiling division without floats.
    return -(-n_frames // chunk)
def _water_atom_indices(top):
    """Return the atom indices of every water residue as an (n_waters, 3) int array.

    Rows follow the topology's residue order; columns follow each residue's
    atom order. Raises ValueError when a water residue does not have exactly
    three atoms, because the stripping logic assumes a 3-site water model.
    """
    rows = [[a.index for a in r.atoms] for r in top.residues if r.is_water]
    if any(len(row) != 3 for row in rows):
        raise ValueError('only 3-site water models are supported')
    # reshape keeps a well-defined (0, 3) shape when the system has no water
    return np.array(rows, dtype=int).reshape(-1, 3)


def _closest_waters(traj, ligand_idx, water_idx, n_keep):
    """Return, per frame, the atom indices of the ``n_keep`` waters nearest the ligand.

    A water's distance to the ligand is the minimum over all (ligand atom,
    water atom) pairs as returned by ``mdtraj.compute_distances``.

    Returns an (n_frames, n_keep * 3) int array. In each row, waters are
    ordered by increasing distance, and each contributes its three atoms in
    the column order of ``water_idx``.
    """
    n_frames = len(traj)
    flat_water = water_idx.reshape(-1)
    # All (ligand atom, water atom) pairs, ligand-major: ligand atom l, water w,
    # water atom a lands at flat position (l * n_waters + w) * 3 + a.
    atom_pairs = np.column_stack((np.repeat(ligand_idx, flat_water.size),
                                  np.tile(flat_water, ligand_idx.size)))
    dist = mdtraj.compute_distances(traj, atom_pairs=atom_pairs)
    dist = dist.reshape(n_frames, ligand_idx.size, len(water_idx), 3)
    mindist = dist.min(axis=(1, 3))  # (n_frames, n_waters)
    nearest = np.argsort(mindist, axis=1, kind='stable')[:, :n_keep]
    return water_idx[nearest].reshape(n_frames, -1)


def strip_most_water(traj, n_water=200, ligand_selection='resname IMAT',
                     ion_selection='resname CLA or resname POT'):
    """Reduce ``traj`` to protein, ligand, ions and the waters nearest the ligand.

    Every frame of the result has the same atoms: the protein, the ligand, the
    ions and the first ``n_water`` water residues of the original topology.
    Frame by frame, the coordinates written into those water slots are those of
    the ``n_water`` waters closest to the ligand (closest first). Water residue
    identities are therefore NOT preserved across frames.

    Parameters
    ----------
    traj : mdtraj.Trajectory
        Input trajectory (or iterload chunk) with the full topology.
    n_water : int
        Number of water molecules to keep; capped at the number present.
    ligand_selection : str
        MDTraj selection string for the ligand.
    ion_selection : str
        MDTraj selection string for the ions to keep.

    Returns
    -------
    mdtraj.Trajectory
        The reduced trajectory. Time and unit-cell data are copied from ``traj``.

    Raises
    ------
    ValueError
        If ``n_water`` is negative, a water residue does not have three atoms,
        or waters must be ranked but ``ligand_selection`` matches no atom.
    """
    # TODO: take additional measures to preserve the identity of water molecules if possible
    if n_water < 0:
        raise ValueError('n_water must be non-negative, got %r' % (n_water,))
    top = traj.top
    protein_idx = top.select('protein')
    ligand_idx = top.select(ligand_selection)
    ions_idx = top.select(ion_selection)
    water_idx = _water_atom_indices(top)
    n_keep = min(n_water, len(water_idx))
    first_water_idx = water_idx[:n_keep].reshape(-1)

    # Kept atoms in topology iteration order. Topology.subset() lays out the new
    # topology in this order regardless of the order of the indices it gets, so
    # the coordinate columns must follow it too. A plain protein/ligand/water/ions
    # concatenation mismatches whenever e.g. the ions precede the water.
    keep = set(np.concatenate((protein_idx, ligand_idx, first_water_idx, ions_idx)).tolist())
    template = np.array([a.index for a in top.atoms if a.index in keep], dtype=int)
    reduced_top = top.subset(template)

    # source[f, j] is the original atom whose coordinates fill slot j of frame f.
    # Non-water slots keep their own atom; water slots are refilled below.
    n_frames = len(traj)
    source = np.tile(template, (n_frames, 1))
    if n_keep > 0:
        if ligand_idx.size == 0:
            raise ValueError('ligand selection %r matched no atoms' % (ligand_selection,))
        slot_of = {atom: j for j, atom in enumerate(template.tolist())}
        water_slots = np.array([slot_of[i] for i in first_water_idx.tolist()], dtype=int)
        source[:, water_slots] = _closest_waters(traj, ligand_idx, water_idx, n_keep)
    new_xyz = traj.xyz[np.arange(n_frames)[:, None], source]

    return mdtraj.Trajectory(new_xyz, reduced_top, time=traj.time,
                             unitcell_lengths=traj.unitcell_lengths,
                             unitcell_angles=traj.unitcell_angles)
def _stripped_chunks(filenames, top, n_water, ligand_selection, progress):
    """Yield water-stripped, imaged chunks from each input trajectory in turn.

    ``progress`` wraps each chunk iterator (e.g. ``tqdm.tqdm``) and receives the
    chunk count from ``tl`` as ``total``.
    """
    for fname_input in filenames:
        print('processing trajectory piece', fname_input, 'frame count:', tl(fname_input, chunk=1))
        for chunk in progress(mdtraj.iterload(fname_input, top=top), total=tl(fname_input)):
            chunk = strip_most_water(chunk, n_water=n_water, ligand_selection=ligand_selection)
            chunk.image_molecules(inplace=True)
            yield chunk


def main(argv=None):
    """Command-line entry point: strip far-away water from trajectory pieces.

    An ``.xtc`` output is written chunk by chunk (streaming). Any other output
    format is assembled in memory and saved once at the end.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse instead of ``sys.argv[1:]``.
    """
    import argparse
    import os
    from tqdm import tqdm

    parser = argparse.ArgumentParser(description='wrap molecules, remove far-away water molecules',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--top', metavar='filename', required=True,
                        help='(in) file name of the full topology file')
    parser.add_argument('--nwat', metavar='number', default=200, type=int,
                        help='(in) number of water molecules to keep')
    parser.add_argument('--output', metavar='trajname', default='(auto)',
                        help='(out) file name for output trajectory. Choosing xtc format enables full online processing.')
    parser.add_argument('--input', metavar='trajname', required=True, nargs='+',
                        help='(in) file name of input trajectory')
    parser.add_argument('--ligand', metavar='selection', default='resname IMAT or resname IMA',
                        help='(in) MDTraj selection string for the ligand')
    args = parser.parse_args(argv)

    if args.output == '(auto)':
        out_name = os.path.splitext(args.input[0])[0] + '.small.xtc'
    else:
        out_name = args.output

    top = mdtraj.load_topology(args.top)
    chunks = _stripped_chunks(args.input, top, args.nwat, args.ligand, tqdm)

    if out_name.endswith('.xtc'):
        # full online version: each chunk is written as soon as it is
        # processed, so memory use is bounded by the chunk size
        with mdtraj.formats.XTCTrajectoryFile(out_name, 'w') as f:
            for chunk in chunks:
                f.write(chunk.xyz, time=chunk.time, box=chunk.unitcell_vectors)
    else:
        pieces = list(chunks)
        if not pieces:
            parser.error('no frames were read from the input trajectories')
        # A single join over the whole list copies the coordinates once.
        # Joining inside the loop would re-copy the growing result per chunk.
        pieces[0].join(pieces[1:]).save(out_name)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment