# Source: GitHub gist bougui505/3079d55110a68e1a319ab26fe64f94fd
# (by @bougui505, created October 1, 2020 07:36).
#!/usr/bin/env python3
# -*- coding: UTF8 -*-
# Author: Guillaume Bouvier -- guillaume.bouvier@pasteur.fr
# https://research.pasteur.fr/en/member/guillaume-bouvier/
# 2020-09-29 15:48:44 (UTC+0200)
import pymol.cmd as cmd
import torch
import sys
def print_progress(instr):
    """
    Redraw the current terminal line with a progress message.

    The message is written to stdout followed by a carriage return (no
    newline), so the next call overwrites the same line. stdout is flushed
    so the message shows up immediately.

    Args:
        instr: object to display; it is rendered through an f-string.
    """
    sys.stdout.write(f'{instr}\r')
    sys.stdout.flush()
def get_cmap(coords, device, threshold=8.):
    """
    Compute a soft (differentiable) contact map from point coordinates.

    Every pairwise Euclidean distance d is mapped to sigmoid(threshold - d):
    the value is 0.5 at d == threshold, tends to 1 for closer pairs and to 0
    for more distant ones.

    Args:
        coords: tensor of points, one per row, as accepted by torch.cdist.
        device: device the resulting map is moved to.
        threshold: distance at which the contact value is 0.5, expressed in
            the same units as `coords`.

    Returns:
        Square tensor of pairwise contact values, located on `device`.
    """
    pdist = torch.cdist(coords, coords)
    # torch.sigmoid is the functional form of torch.nn.Sigmoid(); it avoids
    # instantiating a module on every call.
    cmap = torch.sigmoid(threshold - pdist)
    return cmap.to(device)
def get_coords(pdbfilename, object, device, selection=None):
    """
    Load a structure file into PyMOL and return selected coordinates.

    Side effect: after loading, every atom of the PyMOL object `object`
    that is not named CA is removed from the PyMOL session.

    Args:
        pdbfilename: path of the structure file handed to `cmd.load`.
        object: name of the PyMOL object to create. (The name shadows the
            `object` builtin; it is kept because it is part of the
            signature.)
        device: device the returned tensor is moved to.
        selection: PyMOL selection whose coordinates are returned; defaults
            to the CA atoms of `object`.

    Returns:
        Tensor built from the array returned by `cmd.get_coords`, on
        `device`.

    Raises:
        ValueError: if the selection matches no atom.
    """
    if selection is None:
        selection = f'{object} and name CA'
    cmd.load(pdbfilename, object=object)
    cmd.remove(f'(not name CA) and {object}')
    coords = cmd.get_coords(selection=selection)
    if coords is None:
        # PyMOL returns None for an empty selection; fail with a clear
        # message instead of an opaque TypeError from torch.from_numpy.
        raise ValueError(
            f'selection {selection!r} matches no atom in {pdbfilename!r}')
    return torch.from_numpy(coords).to(device)
def permute(coords, weights):
    """
    Mix the rows of `coords` with a (soft) permutation matrix.

    Computes (coords^T . weights)^T, i.e. output row j is
    sum_i weights[i, j] * coords[i]. With a 0/1 permutation matrix this
    reorders the points; with the identity it returns them unchanged.

    Args:
        coords: 2-D tensor with one point per row (n rows).
        weights: n x n mixing matrix.

    Returns:
        Tensor with the same shape as `coords`.
    """
    # A variant that first normalises `weights` with a softmax over dim=1
    # was tried by the author and left disabled.
    return coords.t().mm(weights).t()
def build_rotation_matrix(alpha_beta_gamma, device):
    """
    Build a differentiable 3x3 rotation matrix from three angles.

    The result is R = RZ(gamma) . RY(beta) . RX(alpha), where RX, RY and RZ
    are the elementary rotations about the x, y and z axes.

    Args:
        alpha_beta_gamma: three angles in radians (a tensor of 3 elements
            or any iterable of 3 scalars). Gradients flow back to the
            angles when they are tensors that require grad.
        device: device on which the matrix is built.

    Returns:
        3x3 tensor on `device`.
    """
    alpha, beta, gamma = alpha_beta_gamma
    zero = torch.zeros(1, device=device)
    one = torch.ones(1, device=device)
    # Promote each angle to a shape-(1,) tensor on `device` so it can be
    # stacked with `zero` / `one`. No extra requires_grad leaves are needed:
    # autograd already tracks the angles themselves.
    alpha, beta, gamma = one * alpha, one * beta, one * gamma
    cos_a, sin_a = torch.cos(alpha), torch.sin(alpha)
    cos_b, sin_b = torch.cos(beta), torch.sin(beta)
    cos_g, sin_g = torch.cos(gamma), torch.sin(gamma)
    # Each inner stack yields a (3, 1) row; the outer stack is (3, 3, 1),
    # hence the final reshape.
    RX = torch.stack([torch.stack([one, zero, zero]),
                      torch.stack([zero, cos_a, -sin_a]),
                      torch.stack([zero, sin_a, cos_a])]).reshape(3, 3)
    RY = torch.stack([torch.stack([cos_b, zero, sin_b]),
                      torch.stack([zero, one, zero]),
                      torch.stack([-sin_b, zero, cos_b])]).reshape(3, 3)
    RZ = torch.stack([torch.stack([cos_g, -sin_g, zero]),
                      torch.stack([sin_g, cos_g, zero]),
                      torch.stack([zero, zero, one])]).reshape(3, 3)
    return RZ.mm(RY).mm(RX)
def build_reflection_matrix(abc, device):
    """
    Build the 3x3 matrix I - 2 * v * v^T for v = (a, b, c).

    When v has unit norm this is the reflection through the plane
    a*x + b*y + c*z = 0. The vector is NOT normalised here; for a non-unit
    v the result is not a true (orthogonal) reflection.
    See: https://en.wikipedia.org/w/index.php?title=Transformation_matrix&oldid=976277111#Reflection

    The matrix is assembled with differentiable tensor operations, so
    gradients flow back to `abc`. (Building it with torch.tensor([[...]])
    copies the values and detaches the result from the autograd graph, so
    `abc` could never be optimised.)

    Args:
        abc: the three components (a, b, c), as a tensor of 3 elements or
            any iterable of 3 scalars.
        device: device on which the matrix is built.

    Returns:
        3x3 tensor on `device`.
    """
    a, b, c = abc
    normal = torch.stack([torch.as_tensor(x, device=device).reshape(())
                          for x in (a, b, c)])
    # Outer product v * v^T via broadcasting.
    outer = normal.unsqueeze(1) * normal.unsqueeze(0)
    return torch.eye(3, device=device) - 2 * outer
def transform(coords, T, device):
    """
    Centre the points on their centroid, then apply a linear transform.

    Fixes two defects of the previous version: the centred coordinates
    were computed and then discarded (the raw `coords` were multiplied by
    `T`), and the centring subtracted the scalar mean of all elements
    instead of the per-axis centroid.

    Args:
        coords: 2-D tensor with one point per row.
        T: matrix applied on the right (row-vector convention:
            out = centred . T).
        device: unused; kept so existing callers keep working.

    Returns:
        Tensor of transformed, centroid-centred coordinates.
    """
    centered = coords - coords.mean(dim=0)
    return centered.mm(T)
def minsum(v, axis=1, n=2., eps=1e-6):
    """
    Smooth approximation of the minimum of `v` along `axis`.

    Computes (sum_i (v_i + eps)^-n)^(-1/n). For positive entries the result
    is never larger than the smallest (v_i + eps) and gets closer to it as
    `n` grows.

    Args:
        v: array-like supporting arithmetic and `.sum(axis=...)`.
        axis: axis along which the reduction is done.
        n: exponent controlling how sharply the minimum dominates.
        eps: small offset added to `v` to avoid dividing by zero.

    Returns:
        `v` reduced along `axis`.
    """
    inverse_powers = 1 / (v + eps) ** n
    return (1 / inverse_powers.sum(axis=axis)) ** (1. / n)
def anchor_loss(coords, anchors):
    """
    Mean squared distance from each point to its nearest anchor.

    Both point sets are first centred on their own centroid, so the loss
    does not depend on a global translation of either set.

    Args:
        coords: 2-D tensor with one point per row.
        anchors: 2-D tensor with one anchor per row.

    Returns:
        Scalar tensor: mean over `coords` of the squared distance to the
        closest anchor.
    """
    centered_coords = coords - coords.mean(dim=0)
    centered_anchors = anchors - anchors.mean(dim=0)
    cdist = torch.cdist(centered_coords, centered_anchors)
    # Hard minimum over anchors; `minsum(cdist, axis=1)` is the smooth
    # alternative the author left disabled.
    mindists, _ = torch.min(cdist, dim=1)
    return (mindists ** 2).mean()
def cmap_loss(cmap_pred, cmap_true, w0=0.05):
    """
    Weighted binary cross-entropy between two contact maps.

    Each element is weighted by (cmap_true + w0), so entries with a high
    target value dominate the loss while the others still contribute with
    weight w0.

    Args:
        cmap_pred: predicted contact map, values in [0, 1].
        cmap_true: target contact map, same number of elements.
        w0: constant added to every per-element weight.

    Returns:
        Scalar tensor: mean weighted BCE over all elements.
    """
    pred = cmap_pred.flatten()
    target = cmap_true.flatten()
    weights = target + w0
    return torch.nn.functional.binary_cross_entropy(pred, target,
                                                    weight=weights)
def align_structures(coords, coords_ref, device, n_iter):
    """
    Fit a rotation + reflection mapping `coords` onto `coords_ref`.

    Three rotation angles and three reflection parameters are drawn at
    random and optimised with Adam (lr=1e-3) to minimise `anchor_loss`
    between the transformed `coords` and `coords_ref`. The loss is printed
    on one redrawn line every 100 iterations.

    Compared to the previous version, the returned coordinates use the
    transform built from the parameters AFTER the last optimiser step, and
    the full transform T = A . R that was optimised (the reflection was
    previously dropped from the returned result). `n_iter == 0` no longer
    raises a NameError.

    Args:
        coords: 2-D tensor with one point per row, to be aligned.
        coords_ref: 2-D tensor with one reference point per row.
        device: device on which the parameters and matrices are created.
        n_iter: number of optimisation steps.

    Returns:
        `coords` transformed by the fitted T (computed without gradient
        tracking).
    """
    alpha_beta_gamma = torch.randn(3, requires_grad=True, device=device)
    abc = torch.randn(3, device=device, requires_grad=True)
    optimizer = torch.optim.Adam([alpha_beta_gamma, abc], lr=1e-3)
    for t in range(n_iter):
        optimizer.zero_grad()
        R = build_rotation_matrix(alpha_beta_gamma, device)
        A = build_reflection_matrix(abc, device)
        T = A.mm(R)
        coords_out = transform(coords, T, device)
        loss = anchor_loss(coords_out, coords_ref)
        loss.backward()
        optimizer.step()
        if t % 100 == 99:
            print_progress(f'{t+1}/{n_iter}: {loss}')
    sys.stdout.write('\n')
    # Rebuild the transform from the final parameter values; the matrices
    # computed inside the loop predate the last optimizer.step().
    with torch.no_grad():
        R = build_rotation_matrix(alpha_beta_gamma, device)
        A = build_reflection_matrix(abc, device)
        return transform(coords, A.mm(R), device)
def icp(coords, coords_ref, device, n_iter):
    """
    Iterative Closest Point

    At each iteration every point of `coords` is matched to its nearest
    reference point (distances measured between the centroid-centred
    clouds), a 3x3 linear map X is fitted by least squares so that
    coords . X approximates the matched reference points, and `coords` is
    replaced by coords . X. The RMSD of the fit is printed on one redrawn
    line.

    NOTE(review): the matching uses centred clouds while the least-squares
    fit uses the raw coordinates and has no translation term -- confirm
    this is intended.

    Args:
        coords: N x 3 tensor of points to align (N >= 3).
        coords_ref: M x 3 tensor of reference points.
        device: unused; kept so existing callers keep working.
        n_iter: number of ICP iterations.

    Returns:
        The transformed coordinates after `n_iter` iterations.
    """
    # The reference cloud never changes: centre it once.
    ref_centered = coords_ref - coords_ref.mean(dim=0)
    for t in range(n_iter):
        cdist = torch.cdist(coords - coords.mean(dim=0), ref_centered)
        _, argmins = torch.min(cdist, dim=1)
        targets = coords_ref[argmins]
        if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'lstsq'):
            # torch.lstsq is deprecated and missing from recent PyTorch.
            X = torch.linalg.lstsq(coords, targets).solution
        else:
            # Legacy API: the first 3 rows of the result hold the solution.
            X = torch.lstsq(targets, coords)[0][:3]
        fitted = coords.mm(X)
        rmsd = torch.sqrt(((fitted - targets) ** 2).sum(dim=1).mean())
        coords = fitted
        print_progress(f'{t+1}/{n_iter}: {rmsd}')
    # Finish the redrawn progress line, as the other optimisation loops do.
    sys.stdout.write('\n')
    return coords
def minimize(coords, cmap_ref, device, n_iter):
    """
    Optimise a soft permutation of `coords` to match a reference contact map.

    An n x n matrix P, initialised to the identity, is optimised with Adam
    (lr=1e-3) so that the contact map of `permute(coords, P)` minimises
    `cmap_loss` against `cmap_ref`. The loss is printed on one redrawn line
    every 100 iterations.

    Args:
        coords: tensor with one point per row (n rows).
        cmap_ref: reference contact map to reproduce.
        device: device on which P and the predicted contact map live.
        n_iter: number of optimisation steps.

    Returns:
        `permute(coords, P)` for the optimised P. The result is still
        attached to the autograd graph; detach it before converting to
        numpy.
    """
    n = coords.shape[0]
    # Permutation matrix
    P = torch.eye(n, requires_grad=True, device=device)
    optimizer_P = torch.optim.Adam([P], lr=1e-3)
    for t in range(n_iter):
        optimizer_P.zero_grad()
        coords_pred = permute(coords, P)
        cmap_pred = get_cmap(coords_pred, device=device)
        loss_P = cmap_loss(cmap_pred, cmap_ref)
        loss_P.backward()
        optimizer_P.step()
        if t % 100 == 99:
            print_progress(f'{t+1}/{n_iter}: {loss_P}')
    sys.stdout.write('\n')
    return permute(coords, P)
if __name__ == '__main__':
    # Demo: reorder the CA atoms of a model so that its contact map matches
    # the contact map of a reference structure, then save the reordered
    # model and the three contact maps (input, reference, output).
    import matplotlib.pyplot as plt

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    coords_ref = get_coords('5v6p_.pdb', 'ref', device=device)
    cmap_ref = get_cmap(coords_ref, device=device)
    # The reference map is kept soft; binarising it at 0.5 was tried by the
    # author and left disabled.

    coords_in = get_coords('map_to_model_5v6p_8637_.pdb', 'mod', device)
    cmap_in = get_cmap(coords_in, device='cpu')

    coords_out = minimize(coords_in, cmap_ref, device, 10000)
    cmap_out = get_cmap(coords_out, device='cpu').detach().numpy()
    coords_out = coords_out.cpu().detach().numpy()

    # Write the optimised coordinates back into the PyMOL object and save.
    cmd.load_coords(coords_out, 'mod')
    cmd.save('out.pdb', selection='mod')

    for cmap, png in ((cmap_in.cpu().numpy(), 'cmap_in.png'),
                      (cmap_ref.cpu().numpy(), 'cmap_ref.png'),
                      (cmap_out, 'cmap_out.png')):
        plt.matshow(cmap)
        plt.savefig(png)
# (End of gist; the GitHub page footer followed here.)