Created
February 13, 2019 04:07
Star
You must be signed in to star a gist
Clustering Trajectory with msmbuilder and mdtraj
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import mdtraj as md | |
import matplotlib | |
matplotlib.use('Agg') | |
import random | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
import argparse | |
# Banner describing the script; passed to argparse below so that it is shown
# by ``--help``.
d = """
===================================================================
Plots Ligand CoM position projected onto HIV-1 NNRT binding site
beta-sheet plane of C-alpha atoms: Ser105, Ile179, Asp186 relative
to the CoM of the NNRTI pocket.
Also shows positions of reported channels:
-Entrance: C-alpha CoM of Lys103, Glu138 (of PDB chain B), Val179
-Tunnel: Tyr181, Tyr188, Trp229, Phe227
-Groove: Val106, Pro225, Pro236
Residues of PDB 4KO0 chain A, unless stated otherwise.
===================================================================
"""

# RawDescriptionHelpFormatter keeps the banner's own line breaks instead of
# re-wrapping it into a single paragraph.
parser = argparse.ArgumentParser(
    description=d,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("-top", type=str, default="complex.psf", help="topology file")
parser.add_argument("-xyz", type=str, default="4KO0_metad_gbis.dcd", help="traj coordinate file")
parser.add_argument("-stride", type=int, default=1, help="Stride")
parser.add_argument("-cutoff", type=float, default=12, help="Cutoff for COM Distance")
# Parsed at import time; the rest of the script reads the values from ``args``.
args = parser.parse_args()
from msmbuilder.featurizer import LigandContactFeaturizer | |
from msmbuilder.cluster import LandmarkAgglomerative | |
def GetCentroid(traj):
    """Return the most representative frame of ``traj``.

    An all-against-all RMSD matrix is computed over the non-hydrogen atoms.
    Each distance ``d`` is turned into a similarity ``exp(-beta * d / std)``
    (``std`` being the standard deviation of the whole matrix, ``beta = 1``)
    and the frame with the largest summed similarity is selected.

    Parameters
    ----------
    traj
        Trajectory-like object exposing ``n_frames``, ``topology.atoms`` and
        integer indexing, and accepted by ``md.rmsd``.

    Returns
    -------
    ``traj[index]`` for the selected frame. The chosen index is also printed.

    Raises
    ------
    ValueError
        If ``traj`` has no frames.
    """
    n_frames = traj.n_frames
    if n_frames == 0:
        # Fail early with a clear message instead of deep inside numpy.
        raise ValueError("cannot pick a centroid from an empty trajectory")

    heavy_atoms = [a.index for a in traj.topology.atoms if a.element.symbol != 'H']
    distances = np.empty((n_frames, n_frames))
    for i in range(n_frames):
        # Row i holds the RMSD of every frame to frame i.
        distances[i] = md.rmsd(traj, traj, i, atom_indices=heavy_atoms)

    beta = 1
    spread = distances.std()
    if spread > 0:
        index = np.exp(-beta * distances / spread).sum(axis=1).argmax()
    else:
        # A single frame (or identical frames) gives a zero spread; dividing
        # by it would produce NaNs. Every frame is equally central, so take
        # the first one.
        index = 0
    print(index)
    return traj[index]
def ClusterBasedOnLigandContacts(traj,data,cutoff):
    """Cluster frames by ligand contacts and save one centroid per cluster.

    Frames are featurized with ``LigandContactFeaturizer`` (reference frame
    ``traj[0]``, ``binding_pocket=1``) and clustered with
    ``LandmarkAgglomerative`` (Ward linkage, Euclidean metric) into
    ``int(cutoff / 0.5)`` clusters. For every cluster label the centroid
    frame returned by ``GetCentroid`` is collected; the centroids are joined
    and written to ``snaps.pdb``.

    Parameters
    ----------
    traj
        Trajectory to cluster.
    data
        ``pandas.DataFrame`` with a ``TIME`` column and one row per frame.
        Its index labels are used as frame indices into ``traj``, so it is
        assumed to carry a default positional index aligned with ``traj``
        (TODO confirm for other callers). The frame passed in is not
        modified; the labelled copy is returned.
    cutoff
        Distance cutoff; only used to derive the number of clusters.

    Returns
    -------
    tuple
        ``(labelled_data, all_snaps)``: a copy of ``data`` with a ``Label``
        column, sorted by ``TIME`` then ``Label``, and the joined centroid
        frames.
    """
    feat = LigandContactFeaturizer(reference_frame=traj[0], binding_pocket=1)
    pocket_contacts = feat.transform(traj)
    clusterer = LandmarkAgglomerative(
        n_clusters=int(cutoff/0.5),
        n_landmarks=None,
        linkage='ward',
        metric='euclidean',
    )
    clusterer.fit(pocket_contacts)
    # Each prediction is a sequence; its first element is taken as the label
    # (presumably one single-frame entry per frame -- verify against the
    # featurizer output).
    predict_assignments = [p[0] for p in clusterer.predict(pocket_contacts)]
    print(data.describe())
    print(len(predict_assignments))

    # assign() returns a new frame, leaving the caller's DataFrame untouched.
    labelled = data.assign(Label=np.array(predict_assignments))
    labelled = labelled.sort_values(['TIME','Label'])

    cluster_cents = [
        GetCentroid(traj[labelled[labelled.Label == label].index])
        for label in labelled.Label.unique()
    ]
    all_snaps = md.join(cluster_cents,check_topology=False)
    print('Saving Snap shots')
    all_snaps.save_pdb('snaps.pdb')
    return (labelled, all_snaps)
def PlotPMF(step_num):
    """Plot the PMF profile for a given step and save it as ``DistVsBias.png``.

    Reads the whitespace-separated file ``4KO0_135_metad.<step_num>.pmf``
    (lines starting with ``#`` are skipped), treats its two columns as
    distance and energy, and writes the curve to ``DistVsBias.png`` at
    300 dpi.

    Parameters
    ----------
    step_num
        Step number used to build the input file name; converted with
        ``int()``.

    Returns
    -------
    None
    """
    # sep=r'\s+' is the documented equivalent of the deprecated
    # delim_whitespace=True.
    pmf = pd.read_csv(
        '4KO0_135_metad.%d.pmf'%int(step_num),
        sep=r'\s+',
        header=None,
        comment='#',
    )
    pmf.columns = ['DIST','ENERGY']
    fig = plt.figure(figsize=(15,10),tight_layout=True)
    plt.plot(pmf.DIST,pmf.ENERGY,color='k',label='PMF',linewidth=3,linestyle='-.')
    plt.title('Unbinding Pathway for JLJ 135 from HIV-RT site',fontsize=25)
    plt.xlabel('COM Distance CV (Angstroms)',size=20)
    plt.ylabel('Energy (Kcal/mol)',size=20)
    plt.xlim(-0.5,15.5)
    plt.legend(fontsize=20)
    plt.grid()
    plt.tick_params(axis='both', which='major', labelsize=20)
    plt.tick_params(axis='both', which='minor', labelsize=20)
    plt.savefig('DistVsBias.png',dpi=300)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    return None
def projections_on_betasheet(traj, CoM_xyz_lig, CoM_xyz_pocket):
    """
    Express ligand and channel positions in a per-frame orthonormal basis.

    Assumes the module-level atom index lists ``entrance_CA``, ``tunnel_CA``
    and ``groove_CA`` are defined.

    For every frame, all positions are taken relative to ``CoM_xyz_pocket``.
    The basis is built from the channel centres: the first axis points from
    the entrance to the tunnel, the second is the entrance-to-groove vector
    with its component along the first axis removed, and the third is their
    cross product.

    Returns four arrays (ligand, entrance, tunnel, groove), each holding one
    row of three basis coordinates per frame.
    """
    def _mean_position(atom_ids):
        # Unweighted mean over the atom axis: one position per frame.
        return np.average(traj.atom_slice(atom_ids).xyz, axis=1)

    entrance_path = _mean_position(entrance_CA)
    tunnel_path = _mean_position(tunnel_CA)
    groove_path = _mean_position(groove_CA)

    lig_rows, ent_rows, tun_rows, grv_rows = [], [], [], []
    for frame in range(len(CoM_xyz_lig)):
        origin = CoM_xyz_pocket[frame]
        lig_vec = CoM_xyz_lig[frame] - origin
        ent_vec = entrance_path[frame] - origin
        tun_vec = tunnel_path[frame] - origin
        grv_vec = groove_path[frame] - origin

        # First axis: entrance -> tunnel, normalised.
        axis_a = tun_vec - ent_vec
        axis_a /= np.linalg.norm(axis_a)
        # Second axis: entrance -> groove with the axis_a component removed,
        # then normalised.
        axis_b = grv_vec - ent_vec
        axis_b -= axis_a * np.dot(axis_a, axis_b) / np.dot(axis_a, axis_a)
        axis_b /= np.linalg.norm(axis_b)
        # Third axis completes the basis.
        axis_c = np.cross(axis_a, axis_b)
        basis = (axis_a, axis_b, axis_c)

        targets = (
            (lig_rows, lig_vec),
            (ent_rows, ent_vec),
            (tun_rows, tun_vec),
            (grv_rows, grv_vec),
        )
        for rows, vec in targets:
            rows.append([np.dot(axis, vec) for axis in basis])

    return (
        np.array(lig_rows),
        np.array(ent_rows),
        np.array(tun_rows),
        np.array(grv_rows),
    )
def plot_projections(XYZ, XYZ_e, XYZ_t, XYZ_g,Time):
    """Scatter the first two projected coordinates and save ``ExitPath.png``.

    The ligand points are coloured by their position in the sequence, with
    the colorbar ticks labelled from every 100th entry of ``Time``. The first
    ligand point is marked with a red cross, and the entrance, groove and
    tunnel point clouds are drawn and labelled at their mean position. All
    coordinates are multiplied by 10 before plotting (presumably nm to
    Angstrom, matching the axis labels -- confirm against the caller).
    """
    # Make plot with vertical (default) colorbar
    fig, ax = plt.subplots(figsize=(10,10))
    colour_scale = np.linspace(0,1,len(XYZ))
    cax = plt.scatter(
        XYZ[:,0]*10,
        XYZ[:,1]*10,
        c=colour_scale,
        cmap=plt.cm.RdYlBu_r,
        alpha=0.5,
    )
    cbar = fig.colorbar(cax, ticks=colour_scale[::100])
    cbar.ax.set_yticklabels(Time[::100].astype(int))
    cbar.set_label("time / ps", fontsize=22)

    # Starting position of the ligand.
    plt.scatter(XYZ[0,0]*10, XYZ[0,1]*10, marker="x", s=100, c="red")

    # Channel clouds, each labelled at its mean position.
    channels = (
        (XYZ_e, "green", "Entrance"),
        (XYZ_g, "orange", "Groove"),
        (XYZ_t, "magenta", "Tunnel"),
    )
    for coords, colour, caption in channels:
        plt.scatter(coords[:,0]*10, coords[:,1]*10, c=colour, alpha=0.25,s=25)
        plt.text(
            np.average(coords[:,0])*10,
            np.average(coords[:,1])*10,
            caption,
            fontsize=20,
            color="black",
        )

    # Dashed cross-hair through the origin, then axis cosmetics.
    plt.hlines(0.0, -20, 20, linestyles="--", colors="black")
    plt.xlim(-20,10)
    plt.xlabel(r"$proj_x$ / $\AA$", fontsize=25)
    plt.vlines(0.0, -20, 20, linestyles="--", colors="black")
    plt.ylabel(r"$proj_y$ / $\AA$", fontsize=25)
    plt.ylim(-20,10)
    for tick_kind in ('major', 'minor'):
        plt.tick_params(axis='both', which=tick_kind, labelsize=20)
    plt.tight_layout()
    plt.savefig("ExitPath.png")
####################
# Load the trajectory in chunks of 500 frames, dropping solvent and aligning
# each chunk on the protein of its own first frame to keep memory use down.
trajs = []
for chunk_num, smtraj in enumerate(
        md.iterload(args.xyz, top=args.top, chunk=500, stride=args.stride)):
    smtraj = smtraj.remove_solvent()
    smtraj = smtraj.superpose(smtraj,0,atom_indices=smtraj.top.select('protein'))
    trajs.append(smtraj)
    print('chunk %d of size 500'%chunk_num)
print('Linking Chunks')
traj = md.join(trajs,check_topology=False)
top = traj.topology
# Re-align the joined trajectory so every frame shares one reference frame.
traj.superpose(traj,0,atom_indices=top.select('protein'))

# Atom selections. entrance_CA, groove_CA and tunnel_CA are read as globals
# by projections_on_betasheet().
lig = list(top.select("mass >2 and chainid 2"))
entrance_CA = list(top.select("name CA and ((residue 103 and chainid 1) or (residue 134 and chainid 0) or (residue 181 and chainid 1) or (residue 105 and chainid 1))"))
# The original clause "residue 238 and residue 108" can never match (no atom
# belongs to two residues), which silently dropped both residues from the
# groove; they are now combined with "or".
groove_CA = list(top.select("name CA and ((residue 320 or residue 227 or residue 229 or residue 228 or residue 238 or residue 108) and chainid 1)"))
tunnel_CA = list(top.select("name CA and (residue 231 or residue 190 or residue 183) and chainid 1"))
pocket_atoms_CA = entrance_CA+groove_CA+tunnel_CA

## assumes stride of 10 reported evry 500 steps at 2fs per timestep
# One time stamp per frame, spaced 10.0 apart.
Time = np.arange(len(traj)) * 10.0

# Unweighted centres of the ligand and pocket atoms, per frame.
CoM_xyz_lig = np.average(traj.atom_slice(lig).xyz, axis=1)
CoM_xyz_pocket = np.average(traj.atom_slice(pocket_atoms_CA).xyz, axis=1)
# x10: presumably nm -> Angstrom so the value is comparable with -cutoff.
CoM_dist = np.linalg.norm(CoM_xyz_pocket-CoM_xyz_lig, axis=1) *10.0

# First frame whose ligand-pocket distance exceeds the cutoff.
unbound_frames = np.flatnonzero(CoM_dist > args.cutoff)
if unbound_frames.size:
    idx_1st_unbound = unbound_frames[0]
    n_bound = idx_1st_unbound
else:
    idx_1st_unbound = -1 # not yet unbound
    # Keep every frame for the plot; slicing with the -1 sentinel would
    # silently drop the last one.
    n_bound = len(traj)
print(idx_1st_unbound,traj.time[idx_1st_unbound],int(traj.time[idx_1st_unbound]*500))

XYZ, XYZ_e, XYZ_t, XYZ_g = projections_on_betasheet(traj, CoM_xyz_lig, CoM_xyz_pocket)
plot_projections(XYZ[:n_bound], XYZ_e[:n_bound], XYZ_t[:n_bound], XYZ_g[:n_bound],Time[:n_bound])

if idx_1st_unbound >0:
    # Only consumed by the (currently disabled) PlotPMF call below.
    fin_num_step = traj.time[idx_1st_unbound]*500
    traj=traj[:idx_1st_unbound]
    data = pd.DataFrame({'TIME':Time[:idx_1st_unbound],'COM_DIST':CoM_dist[:idx_1st_unbound]})
    # Written before clustering, so the CSV holds TIME and COM_DIST only.
    data.to_csv('CLUSTERING_RES.csv',index=False,float_format='%3.3f')
    final_data,snaps=ClusterBasedOnLigandContacts(traj,data,args.cutoff)
    # PlotPMF(step_num=fin_num_step)
    traj.save_hdf5('plig_cutoff.h5')
else:
    print('Simple MD SImulation')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment