Skip to content

Instantly share code, notes, and snippets.

@leelasd
Created February 13, 2019 04:07
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save leelasd/e04480424c8499423feb20425dda17a0 to your computer and use it in GitHub Desktop.
Clustering Trajectory with msmbuilder and mdtraj
import numpy as np
import mdtraj as md
import matplotlib
matplotlib.use('Agg')
import random
import matplotlib.pyplot as plt
import pandas as pd
import argparse
# Human-readable description of the script; shown at the top of ``--help``.
d = """
===================================================================
Plots Ligand CoM position projected onto HIV-1 NNRT binding site
beta-sheet plane of C-alpha atoms: Ser105, Ile179, Asp186 relative
to the CoM of the NNRTI pocket.
Also shows positions of reported channels:
-Entrance: C-alpha CoM of Lys103, Glu138 (of PDB chain B), Val179
-Tunnel: Tyr181, Tyr188, Trp229, Phe227
-Groove: Val106, Pro225, Pro236
Residues of PDB 4KO0 chain A, unless stated otherwise.
===================================================================
"""
# The description was previously defined but never handed to the parser.
# RawDescriptionHelpFormatter keeps its line breaks intact in the help text.
parser = argparse.ArgumentParser(
    description=d,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("-top", type=str, default="complex.psf", help="topology file")
parser.add_argument("-xyz", type=str, default="4KO0_metad_gbis.dcd", help="traj coordinate file")
parser.add_argument("-stride", type=int, default=1, help="Stride")
parser.add_argument("-cutoff", type=float, default=12, help="Cutoff for COM Distance")
# Parsed once at import time; the rest of the script reads ``args`` as a global.
args = parser.parse_args()
from msmbuilder.featurizer import LigandContactFeaturizer
from msmbuilder.cluster import LandmarkAgglomerative
def GetCentroid(traj):
    """Return the most central frame of ``traj`` as a one-frame trajectory.

    A pairwise RMSD matrix is computed over all non-hydrogen atoms. Each
    frame ``i`` is scored with ``sum_j exp(-beta * d_ij / std(d))`` and the
    frame with the highest score is returned.

    Parameters
    ----------
    traj : trajectory object
        Must provide ``n_frames``, ``topology.atoms`` and frame indexing,
        and be accepted by ``md.rmsd``.

    Returns
    -------
    The slice ``traj[index]`` for the selected frame. The selected index is
    also printed to stdout.

    Raises
    ------
    ValueError
        If ``traj`` contains no frames.
    """
    if traj.n_frames == 0:
        raise ValueError("GetCentroid requires a trajectory with at least one frame")

    heavy_atoms = [a.index for a in traj.topology.atoms if a.element.symbol != 'H']

    # Row i holds the RMSD of every frame to frame i.
    distances = np.empty((traj.n_frames, traj.n_frames))
    for i in range(traj.n_frames):
        distances[i] = md.rmsd(traj, traj, i, atom_indices=heavy_atoms)

    beta = 1
    spread = distances.std()
    if spread > 0:
        index = np.exp(-beta * distances / spread).sum(axis=1).argmax()
    else:
        # A single frame (or identical frames) gives a zero spread; dividing
        # by it would only produce NaN scores, so take the first frame.
        index = 0
    print(index)
    return traj[index]
def ClusterBasedOnLigandContacts(traj, data, cutoff):
    """Cluster frames on ligand-contact features and save one snapshot per cluster.

    Frames are featurized with ``LigandContactFeaturizer`` (reference frame
    ``traj[0]``) and clustered with ``LandmarkAgglomerative`` using
    ``int(cutoff / 0.5)`` clusters. For every cluster label the centroid
    frame is picked with ``GetCentroid``; the centroids are joined and
    written to ``snaps.pdb``.

    Parameters
    ----------
    traj : trajectory to cluster.
    data : pandas.DataFrame with a ``TIME`` column and one row per frame.
        A ``Label`` column is added to this frame in place.
    cutoff : number used only to derive the cluster count.

    Returns
    -------
    tuple
        ``(data sorted by TIME and Label, joined centroid snapshots)``.
    """
    featurizer = LigandContactFeaturizer(reference_frame=traj[0], binding_pocket=1)
    # Table of feature descriptions; built here but not consulted afterwards.
    feature_table = pd.DataFrame(featurizer.describe_features(traj[0]))
    contacts = featurizer.transform(traj)

    model = LandmarkAgglomerative(
        n_clusters=int(cutoff / 0.5),
        n_landmarks=None,
        linkage='ward',
        metric='euclidean',
    )
    model.fit(contacts)
    # Labels assigned during fitting; read but not used further.
    landmark_labels = model.landmark_labels_

    # predict() yields one entry per input item; the first element is its label.
    labels = []
    for prediction in model.predict(contacts):
        labels.append(prediction[0])

    print(data.describe())
    print(len(labels))

    data['Label'] = np.array(labels)
    data = data.sort_values(['TIME', 'Label'])

    # One centroid snapshot per label, in order of first appearance.
    centroids = [
        GetCentroid(traj[data[data.Label == label].index])
        for label in data.Label.unique()
    ]
    all_snaps = md.join(centroids, check_topology=False)

    print('Saving Snap shots')
    all_snaps.save_pdb('snaps.pdb')
    return (data, all_snaps)
def PlotPMF(step_num):
    """Plot the PMF stored in ``4KO0_135_metad.<step_num>.pmf``.

    The file is read as whitespace-separated columns (distance, energy),
    skipping ``#`` comment lines, and the curve is written to
    ``DistVsBias.png`` at 300 dpi.

    Parameters
    ----------
    step_num : value convertible to ``int``; selects the PMF file to read.

    Returns
    -------
    None
    """
    # sep=r'\s+' replaces the deprecated delim_whitespace=True flag.
    pmf = pd.read_csv(
        '4KO0_135_metad.%d.pmf' % int(step_num),
        sep=r'\s+',
        header=None,
        comment='#',
    )
    pmf.columns = ['DIST', 'ENERGY']
    # NOTE(review): an earlier revision negated ENERGY at this point; that
    # sign flip is disabled, so the values are plotted as stored.
    plt.figure(figsize=(15, 10), tight_layout=True)
    plt.plot(pmf.DIST, pmf.ENERGY, color='k', label='PMF', linewidth=3, linestyle='-.')
    plt.title('Unbinding Pathway for JLJ 135 from HIV-RT site', fontsize=25)
    plt.xlabel('COM Distance CV (Angstroms)', size=20)
    plt.ylabel('Energy (Kcal/mol)', size=20)
    plt.xlim(-0.5, 15.5)
    plt.legend(fontsize=20)
    plt.grid()
    plt.tick_params(axis='both', which='major', labelsize=20)
    plt.tick_params(axis='both', which='minor', labelsize=20)
    plt.savefig('DistVsBias.png', dpi=300)
    return None
def projections_on_betasheet(traj, CoM_xyz_lig, CoM_xyz_pocket,
                             entrance_atoms=None, tunnel_atoms=None,
                             groove_atoms=None):
    """Project ligand and channel centres onto a per-frame channel basis.

    For every frame an orthonormal basis is built by Gram-Schmidt from the
    channel centres, all taken relative to the pocket centre:

    * ``u1`` points from the entrance centre to the tunnel centre,
    * ``u2`` is the entrance->groove direction with its ``u1`` component
      removed,
    * ``u3 = u1 x u2``.

    Parameters
    ----------
    traj : trajectory object providing ``atom_slice(indices).xyz``.
    CoM_xyz_lig : array with one ligand centre (3 components) per frame.
    CoM_xyz_pocket : array with one pocket centre (3 components) per frame.
    entrance_atoms, tunnel_atoms, groove_atoms : sequences of atom indices,
        optional. When omitted, the module-level ``entrance_CA``,
        ``tunnel_CA`` and ``groove_CA`` are used, as before.

    Returns
    -------
    tuple of four numpy arrays ``(XYZ, XYZ_e, XYZ_t, XYZ_g)``
        Basis coordinates, one row per frame, of the ligand, entrance,
        tunnel and groove centres respectively.
    """
    # Fall back to the script-level selections when none are passed in.
    if entrance_atoms is None:
        entrance_atoms = entrance_CA
    if tunnel_atoms is None:
        tunnel_atoms = tunnel_CA
    if groove_atoms is None:
        groove_atoms = groove_CA

    # Unweighted mean position of each channel's atoms, per frame.
    CoM_xyz_entrance = np.average(traj.atom_slice(entrance_atoms).xyz, axis=1)
    CoM_xyz_tunnel = np.average(traj.atom_slice(tunnel_atoms).xyz, axis=1)
    CoM_xyz_groove = np.average(traj.atom_slice(groove_atoms).xyz, axis=1)

    XYZ, XYZ_e, XYZ_t, XYZ_g = [], [], [], []
    for t in range(len(CoM_xyz_lig)):
        # Everything is expressed relative to the pocket centre.
        xyz = CoM_xyz_lig[t] - CoM_xyz_pocket[t]
        xyz_e = CoM_xyz_entrance[t] - CoM_xyz_pocket[t]
        xyz_t = CoM_xyz_tunnel[t] - CoM_xyz_pocket[t]
        xyz_g = CoM_xyz_groove[t] - CoM_xyz_pocket[t]

        # Gram-Schmidt orthonormalisation.
        u1 = xyz_t - xyz_e
        u1 = u1 / np.linalg.norm(u1)
        u2 = xyz_g - xyz_e
        u2 = u2 - u1 * np.dot(u1, u2) / np.dot(u1, u1)
        u2 = u2 / np.linalg.norm(u2)
        u3 = np.cross(u1, u2)
        basis = (u1, u2, u3)

        XYZ.append([np.dot(u, xyz) for u in basis])
        XYZ_e.append([np.dot(u, xyz_e) for u in basis])
        XYZ_t.append([np.dot(u, xyz_t) for u in basis])
        XYZ_g.append([np.dot(u, xyz_g) for u in basis])

    return np.array(XYZ), np.array(XYZ_e), np.array(XYZ_t), np.array(XYZ_g)
def plot_projections(XYZ, XYZ_e, XYZ_t, XYZ_g, Time):
    """Scatter the first two projected coordinates and save ``ExitPath.png``.

    The ligand path is coloured by frame order (0..1 along the path) with a
    colorbar labelled from every 100th entry of ``Time``. The first ligand
    point is marked with a red cross, and the entrance, groove and tunnel
    points are drawn with a text label at their mean position. All
    coordinates are multiplied by 10 before plotting; the axis labels are in
    Angstrom (presumably the inputs are in nm - confirm against the caller).

    Parameters
    ----------
    XYZ, XYZ_e, XYZ_t, XYZ_g : 2-D arrays; columns 0 and 1 are plotted.
    Time : numpy array; every 100th value, cast to int, labels the colorbar.
    """
    figure, axes = plt.subplots(figsize=(10, 10))

    # Ligand path, coloured by its position along the trajectory.
    progress = np.linspace(0, 1, len(XYZ))
    path = axes.scatter(XYZ[:, 0] * 10, XYZ[:, 1] * 10,
                        c=progress, cmap=plt.cm.RdYlBu_r, alpha=0.5)
    colorbar = figure.colorbar(path, ticks=progress[::100])
    colorbar.ax.set_yticklabels(Time[::100].astype(int))
    colorbar.set_label("time / ps", fontsize=22)

    # Starting point of the ligand.
    axes.scatter(XYZ[0, 0] * 10, XYZ[0, 1] * 10, marker="x", s=100, c="red")

    # Channel markers, each captioned at the mean of its points.
    channels = (
        (XYZ_e, "green", "Entrance"),
        (XYZ_g, "orange", "Groove"),
        (XYZ_t, "magenta", "Tunnel"),
    )
    for points, colour, caption in channels:
        axes.scatter(points[:, 0] * 10, points[:, 1] * 10, c=colour, alpha=0.25, s=25)
        axes.text(np.average(points[:, 0]) * 10, np.average(points[:, 1]) * 10,
                  caption, fontsize=20, color="black")

    # Dashed cross-hair through the origin, then fixed axis ranges.
    axes.hlines(0.0, -20, 20, linestyles="--", colors="black")
    axes.set_xlim(-20, 10)
    axes.set_xlabel(r"$proj_x$ / $\AA$", fontsize=25)
    axes.vlines(0.0, -20, 20, linestyles="--", colors="black")
    axes.set_ylabel(r"$proj_y$ / $\AA$", fontsize=25)
    axes.set_ylim(-20, 10)

    for which in ('major', 'minor'):
        axes.tick_params(axis='both', which=which, labelsize=20)

    figure.tight_layout()
    figure.savefig("ExitPath.png")
####################
# Driver: load the trajectory, locate the first unbound frame, plot the
# ligand path and (if the ligand unbinds) cluster the bound frames.

# Read the trajectory in chunks of 500 frames, stripping solvent from each
# chunk before the chunks are joined.
trajs = []
for chunk_num, smtraj in enumerate(
        md.iterload(args.xyz, top=args.top, chunk=500, stride=args.stride)):
    smtraj = smtraj.remove_solvent()
    smtraj = smtraj.superpose(smtraj, 0, atom_indices=smtraj.top.select('protein'))
    trajs.append(smtraj)
    print('chunk %d of size 500' % chunk_num)
print('Linking Chunks')
traj = md.join(trajs, check_topology=False)
top = traj.topology
# Re-align the joined trajectory on the protein atoms of its first frame.
traj.superpose(traj, 0, atom_indices=top.select('protein'))

# Atom selections. These names are module-level globals that
# projections_on_betasheet falls back on.
# NOTE(review): the residue numbers appear to be offset by +2 from the PDB
# numbering quoted in the script description - confirm against the topology.
lig = list(top.select("mass >2 and chainid 2"))
entrance_CA = list(top.select("name CA and ((residue 103 and chainid 1) or (residue 134 and chainid 0) or (residue 181 and chainid 1) or (residue 105 and chainid 1))"))
# Fixed: the last clause used to read "residue 238 and residue 108", which can
# never match, so both residues were silently left out of the groove.
groove_CA = list(top.select("name CA and ((residue 320 or residue 227 or residue 229 or residue 228 or residue 238 or residue 108) and chainid 1)"))
tunnel_CA = list(top.select("name CA and (residue 231 or residue 190 or residue 183) and chainid 1"))
pocket_atoms_CA = entrance_CA + groove_CA + tunnel_CA

# Frame times are hard-coded at 10 per frame (labelled ps in the plot);
# assumes a stride of 10 with frames reported every 500 steps at 2 fs per step.
Time = np.arange(len(traj)) * 10.0

# Unweighted mean positions of the ligand and pocket atoms, per frame.
CoM_xyz_lig = np.average(traj.atom_slice(lig).xyz, axis=1)
CoM_xyz_pocket = np.average(traj.atom_slice(pocket_atoms_CA).xyz, axis=1)
# x10 converts mdtraj's nanometres to Angstrom.
CoM_dist = np.linalg.norm(CoM_xyz_pocket - CoM_xyz_lig, axis=1) * 10.0

# First frame whose ligand-pocket distance exceeds the cutoff. None means the
# ligand never unbinds; the former -1 sentinel was also used as a slice bound
# and index, which dropped the last frame and reported its time as unbinding.
unbound_frames = np.where(CoM_dist > args.cutoff)[0]
if len(unbound_frames) > 0:
    idx_1st_unbound = int(unbound_frames[0])
    print(idx_1st_unbound, traj.time[idx_1st_unbound], int(traj.time[idx_1st_unbound] * 500))
else:
    idx_1st_unbound = None
    print('Ligand never exceeds the cutoff distance')

XYZ, XYZ_e, XYZ_t, XYZ_g = projections_on_betasheet(traj, CoM_xyz_lig, CoM_xyz_pocket)

if idx_1st_unbound is None:
    # No unbinding event: plot every frame, nothing to cluster.
    plot_projections(XYZ, XYZ_e, XYZ_t, XYZ_g, Time)
    print('Simple MD SImulation')
elif idx_1st_unbound == 0:
    # There are no bound frames, so there is nothing to plot or cluster.
    print('Ligand is already beyond the cutoff in the first frame')
else:
    plot_projections(XYZ[:idx_1st_unbound], XYZ_e[:idx_1st_unbound],
                     XYZ_t[:idx_1st_unbound], XYZ_g[:idx_1st_unbound],
                     Time[:idx_1st_unbound])
    # Only read by the disabled PlotPMF call below.
    fin_num_step = traj.time[idx_1st_unbound] * 500
    # Keep only the frames before the first unbound one.
    traj = traj[:idx_1st_unbound]
    data = pd.DataFrame({'TIME': Time[:idx_1st_unbound], 'COM_DIST': CoM_dist[:idx_1st_unbound]})
    data.to_csv('CLUSTERING_RES.csv', index=False, float_format='%3.3f')
    final_data, snaps = ClusterBasedOnLigandContacts(traj, data, args.cutoff)
    # PlotPMF(step_num=fin_num_step)
    traj.save_hdf5('plig_cutoff.h5')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment