Skip to content

Instantly share code, notes, and snippets.

@tdudgeon
Created June 1, 2020 15:09
Show Gist options
  • Save tdudgeon/ee6f441e4828b98d904cc4c5b56ba105 to your computer and use it in GitHub Desktop.
Save tdudgeon/ee6f441e4828b98d904cc4c5b56ba105 to your computer and use it in GitHub Desktop.
Preparation for tethered docking
import argparse, os, sys, json, traceback
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdFMCS
from rdkit.Chem import rdMolTransforms
from rdkit.Chem.rdMolAlign import *
from rdkit.Chem.rdForceFieldHelpers import *
from rdkit import RDLogger
from rdkit import rdBase
from dimorphite_dl import run_with_mol_list
def log(*args, **kwargs):
    """Print ``args`` to standard error.

    Any keyword arguments are forwarded to :func:`print`. ``file`` must not be
    passed, because the destination is always ``sys.stderr``.
    """
    target = sys.stderr
    print(*args, **kwargs, file=target)
def enumerate_charges(mol, min_ph, max_ph):
    """
    Enumerate the charge (protonation) states of ``mol`` with Dimorphite-DL.

    :param mol: the RDKit molecule to protonate
    :param min_ph: lower bound of the pH range passed to ``run_with_mol_list``
    :param max_ph: upper bound of the pH range passed to ``run_with_mol_list``
    :return: the value returned by ``run_with_mol_list`` for a one-molecule list;
        the caller treats it as a sequence of RDKit molecules
    """
    return run_with_mol_list([mol], min_ph=min_ph, max_ph=max_ph)
# Number of molecules counted for one hit before a new SD file is started
# (set from --chunk-size in main()).
chunk_size = 0
# Current SD writer for each hit id.
writers = {}
# Cumulative number of molecules processed for each hit id (incremented by execute()).
molecule_counts = {}
# Index of the file currently being written for each hit id (used in the file name).
file_counts = {}
# Value of molecule_counts[id] at the moment the current file for that id was opened.
_chunk_starts = {}


def get_writer(id, outfile_base):
    """Return the SD writer for hit ``id``, rolling over to a new file when the chunk is full.

    Files are named ``<id>_<outfile_base>_<NNN>.sdf``. A new file is started the
    first time an id is seen, and again once ``chunk_size`` molecules have been
    counted for that id since its current file was opened (so ``chunk_size <= 0``
    starts a new file on every call). The previous writer is closed first.

    ``molecule_counts`` is deliberately never reset here: it stays cumulative, so
    the per-hit totals reported at the end of a run cover every chunk.

    :param id: the hit identifier, used as the file name prefix
    :param outfile_base: base name for the output files
    :return: an open ``Chem.SDWriter``
    """
    count = molecule_counts.setdefault(id, 0)
    writer = writers.get(id)
    # Reuse the current file until it holds chunk_size molecules (strictly fewer
    # than chunk_size so far means there is still room).
    if writer is not None and count - _chunk_starts.get(id, 0) < chunk_size:
        return writer

    # We need a new writer: close the old one (if any) and bump the file index.
    if writer is not None:
        file_count = file_counts[id]
        print('Closing writer', id, file_count)
        writer.close()
        file_count += 1
    else:
        file_count = 0

    writer = Chem.SDWriter(id + '_' + outfile_base + '_' + f'{file_count:03}' + '.sdf')
    writers[id] = writer
    file_counts[id] = file_count
    _chunk_starts[id] = count
    return writer
def get_h_bonds(atom):
    """Return the indices of the bonds that join ``atom`` to hydrogen atoms.

    :param atom: an RDKit atom
    :return: tuple of bond indices, in the order ``atom.GetBonds()`` yields them
    """
    return tuple(
        bond.GetIdx()
        for bond in atom.GetBonds()
        if bond.GetOtherAtom(atom).GetAtomicNum() == 1
    )
def get_num_h_neighbours(atom):
    """Count the neighbours of ``atom`` that are hydrogen atoms (atomic number 1).

    Only hydrogens present as atoms in the molecular graph are counted, so the
    molecule typically needs explicit Hs (e.g. via ``Chem.AddHs``).

    :param atom: an RDKit atom
    :return: number of hydrogen neighbours
    """
    return sum(1 for neighbour in atom.GetNeighbors() if neighbour.GetAtomicNum() == 1)
def count_h_attachments_for_match(mol, match):
    """Return the hydrogen-neighbour count of each atom in a substructure match.

    :param mol: the RDKit molecule the match refers to
    :param match: sequence of atom indices in ``mol``
    :return: tuple of counts, one per index in ``match`` and in the same order
    """
    return tuple(get_num_h_neighbours(mol.GetAtomWithIdx(atom_idx)) for atom_idx in match)
def count_h_attachements_for_mol(mol):
    """Return the hydrogen count of every heavy atom of ``mol``.

    Hydrogens are made explicit with ``Chem.AddHs`` before counting. The input
    molecule itself is not modified.

    :param mol: an RDKit molecule
    :return: tuple with one count per non-hydrogen atom, in atom-index order
    """
    explicit = Chem.AddHs(mol)
    heavy_atoms = (a for a in explicit.GetAtoms() if a.GetAtomicNum() != 1)
    return tuple(map(get_num_h_neighbours, heavy_atoms))
def count_attachments(mol, match):
    """Return, for each matched atom, the indices of its bonds to hydrogen.

    :param mol: the RDKit molecule the match refers to
    :param match: sequence of atom indices in ``mol``
    :return: tuple of tuples of bond indices, one inner tuple per index in ``match``
    """
    return tuple(get_h_bonds(mol.GetAtomWithIdx(atom_idx)) for atom_idx in match)
# Number of H/substituent-swapped conformers successfully generated (reported by main()).
num_swap_success = 0


def _find_atoms_to_switch(mol, match, nha, hs_for_core):
    """Find matched atoms whose H and non-core substituent could be exchanged.

    An atom qualifies when all of the following hold:

    - its H count differs from the corresponding core atom (it is substituted);
    - it still carries at least one H;
    - it is held in place, either by two or more tethered (matched) neighbours or
      by a single double bond to a tethered neighbour.

    :param mol: the molecule being embedded
    :param match: atom indices of ``mol`` matching the core atoms, in core order
    :param nha: H-neighbour count of each matched atom
    :param hs_for_core: H count of each core heavy atom
    :return: list of ``(core_idx, atom_idx, num_hs, free_atoms, tethered_atoms)``
        tuples. Every entry has this same 5-field layout, which the swap step relies on.
    """
    atoms_to_switch = []
    for idx, (hs_on_core_atoms, hs_on_matched_atoms) in enumerate(zip(hs_for_core, nha)):
        # Same H count as the core atom: this position is not substituted.
        if hs_on_core_atoms == hs_on_matched_atoms:
            continue
        # Without at least one H there is nothing to swap the substituent with.
        if hs_on_matched_atoms <= 0:
            continue
        atomidx = match[idx]
        atom = mol.GetAtomWithIdx(atomidx)
        free_atoms = []
        free_symbols = []
        tethered_atoms = []
        tethered_symbols = []
        for other_atom in atom.GetNeighbors():
            o_idx = other_atom.GetIdx()
            o_sym = other_atom.GetSymbol() + str(o_idx)
            if o_idx in match:
                tethered_atoms.append(o_idx)
                tethered_symbols.append(o_sym)
            else:
                free_atoms.append(o_idx)
                free_symbols.append(o_sym)
        print(' Atom:', atom.GetSymbol() + str(atomidx), 'Free:', free_symbols, 'Tethered:', tethered_symbols)
        # Include if there are 2 or more tethered atoms - otherwise it's free to rotate.
        if len(tethered_atoms) > 1:
            print(' GT1')
            atoms_to_switch.append((idx, atomidx, hs_on_matched_atoms, tuple(free_atoms), tuple(tethered_atoms)))
        # Special case: only one tether, but the bond is not rotatable. A double bond
        # is used as the test (e.g. not a methyl group, whose 3 Hs are equivalent).
        if len(tethered_atoms) == 1:
            print(' EQ1')
            bond = mol.GetBondBetweenAtoms(atomidx, tethered_atoms[0])
            if bond is not None and bond.GetBondType() == Chem.BondType.DOUBLE:
                print(' NON-ROTATABLE', bond.GetBondType(), atom.GetSymbol() + '->' + bond.GetOtherAtom(atom).GetSymbol())
                # Same layout as above: position 3 must hold the free atoms.
                atoms_to_switch.append((idx, atomidx, hs_on_matched_atoms, tuple(free_atoms), tuple(tethered_atoms)))
    return atoms_to_switch


def _embed_on_core(mol, core, match, getForceField):
    """Embed a copy of ``mol`` with its matched atoms on the core coordinates.

    The copy is aligned onto the core, then minimised with the matched atoms
    tethered (see ``minimize_mol``).

    :return: ``(embedded_mol, algMap)``, where ``algMap`` is a list of
        ``(mol_atom_idx, core_atom_idx)`` pairs
    :raises ValueError: if the embedding fails
    """
    new_mol = Chem.RWMol(mol)
    core_conf = core.GetConformer(-1)
    coordMap = {mol_idx: core_conf.GetAtomPosition(core_idx) for core_idx, mol_idx in enumerate(match)}
    if AllChem.EmbedMolecule(new_mol, coordMap=coordMap) < 0:
        raise ValueError('Could not embed molecule.')
    algMap = [(mol_idx, core_idx) for core_idx, mol_idx in enumerate(match)]
    # rotate the embedded conformation onto the core
    AlignMol(new_mol, core, atomMap=algMap)
    minimize_mol(new_mol, core, match, algMap, getForceField)
    return new_mol, algMap


def _swap_h_and_substituent(new_mol, ate):
    """Return a copy of ``new_mol`` with an H and a free substituent on one atom exchanged.

    Only atoms with exactly two free neighbours, one of which is H, are handled.
    Otherwise ``None`` is returned. The returned copy is not minimised.

    :param new_mol: an embedded molecule
    :param ate: an entry produced by ``_find_atoms_to_switch``
    :return: the modified copy, or ``None``
    """
    _, centre, _, free_atoms, _ = ate
    if len(free_atoms) != 2:  # we'll ignore other cases for now. 5 valent atoms?
        print(' WARNING: more than 2 atoms to switch')
        return None
    m2 = Chem.RWMol(new_mol)
    conf = m2.GetConformer(-1)
    m2.SetProp('Status', 'before')
    first, second = free_atoms
    angle = rdMolTransforms.GetAngleDeg(conf, first, centre, second)
    print(' Swapping atoms',
          m2.GetAtomWithIdx(first).GetSymbol() + str(first), 'and',
          m2.GetAtomWithIdx(second).GetSymbol() + str(second), 'on atom',
          m2.GetAtomWithIdx(centre).GetSymbol() + str(centre),
          'using angle', angle)
    # find a free H atom
    if m2.GetAtomWithIdx(first).GetAtomicNum() == 1:
        h_idx, heavy_idx = first, second
    elif m2.GetAtomWithIdx(second).GetAtomicNum() == 1:
        h_idx, heavy_idx = second, first
    else:
        # both are heavy atoms - don't know how to handle these yet so we give up
        return None
    # Grab the heavy atom's position and the centre-H bond length before moving anything.
    pos = conf.GetAtomPosition(heavy_idx)
    h_bond_length = rdMolTransforms.GetBondLength(conf, centre, h_idx)
    # Collapse the H-centre-substituent angle to 0 so the substituent points along the old H direction...
    rdMolTransforms.SetAngleDeg(conf, h_idx, centre, heavy_idx, 0.0)
    m2.SetProp('Status', 'during')
    print(' angle is now', rdMolTransforms.GetAngleDeg(conf, h_idx, centre, heavy_idx))
    # ...then put the H where the substituent was, restoring its original bond length.
    conf.SetAtomPosition(h_idx, pos)
    rdMolTransforms.SetBondLength(conf, centre, h_idx, h_bond_length)
    print(' angle is now', rdMolTransforms.GetAngleDeg(conf, h_idx, centre, heavy_idx))
    m2.SetProp('Status', 'flipped')
    return m2


# Adapted from RDKit's AllChem.ConstrainedEmbed (its usage instructions were trimmed).
def MultiConstrainedEmbed(mol, core, getForceField=UFFGetMoleculeForceField):
    """Embed ``mol`` onto ``core`` once for every distinct substructure match.

    The original ``ConstrainedEmbed`` uses a single match. Here every match of
    ``core`` in ``mol`` is considered. A match is skipped when its per-atom H
    counts equal those of a match already processed. For each remaining match:

    - the molecule is embedded with the matched atoms on the core coordinates,
      then minimised and realigned;
    - for every matched atom carrying both an H and a non-core substituent
      (see ``_find_atoms_to_switch``), an extra conformer with the two exchanged
      is built and minimised, and the module-level ``num_swap_success`` counter
      is incremented.

    :param mol: the molecule to embed (the caller passes the output of ``Chem.AddHs``)
    :param core: core molecule whose conformer supplies the tether coordinates
    :param getForceField: callable ``(mol, confId=...)`` returning a force field
    :return: list of ``(embedded_mol, match)`` tuples
    :raises ValueError: if ``mol`` doesn't match ``core`` or a match cannot be embedded
    """
    global num_swap_success
    print(Chem.MolToSmiles(mol), Chem.MolToSmiles(core))
    # In original constrained function, only a single match was found.
    # Here we look for multiple matches
    matches = mol.GetSubstructMatches(core, uniquify=False)
    if not matches:
        raise ValueError("molecule doesn't match the core")
    print(' found', len(matches), 'SSS matches')
    mols_matches = []
    hs_for_core = count_h_attachements_for_mol(core)
    print(' COREHS', hs_for_core)
    seen_nhas = set()
    for match in matches:
        nha = count_h_attachments_for_match(mol, match)
        if nha in seen_nhas:
            print(' NHA', nha, 'SSS', match, 'EXISTING')
            continue
        seen_nhas.add(nha)
        atoms_to_switch = _find_atoms_to_switch(mol, match, nha, hs_for_core)
        print(' NHA', nha, 'SSS', match, 'NEW', atoms_to_switch)

        new_mol, algMap = _embed_on_core(mol, core, match, getForceField)
        mols_matches.append((new_mol, match))

        # If there are atoms that need swapping then do so
        for ate in atoms_to_switch:
            m2 = _swap_h_and_substituent(new_mol, ate)
            if m2 is None:
                continue
            minimize_mol(m2, core, match, algMap, getForceField)
            m2.SetProp('Status', 'minimized')
            mols_matches.append((m2, match))
            num_swap_success += 1
    # Return a list of tuples of (mol, match)
    return mols_matches
def minimize_mol(mol, core, match, algMap, getForceField):
    """Minimise conformer 0 of ``mol`` with its matched atoms tethered to the core, then realign.

    Every core atom position becomes a fixed extra point in the force field. The
    matched atom of ``mol`` is bound to that point by a zero-length distance
    constraint (force constant 100). Minimisation is attempted up to five times
    in total, stopping as soon as ``Minimize`` reports convergence. The molecule
    is then aligned onto the core and the alignment RMS is stored as the
    ``EmbedRMS`` double property.

    :param mol: the embedded molecule (modified in place)
    :param core: core molecule providing the tether coordinates
    :param match: atom indices of ``mol`` for each core atom, in core order
    :param algMap: ``(mol_atom_idx, core_atom_idx)`` pairs used for the alignment
    :param getForceField: callable ``(mol, confId=...)`` returning a force field
    """
    field = getForceField(mol, confId=0)
    core_conf = core.GetConformer()
    for core_idx in range(core.GetNumAtoms()):
        anchor = core_conf.GetAtomPosition(core_idx)
        # The "- 1" turns AddExtraPoint's return value into the new point's index
        # (same idiom as RDKit's ConstrainedEmbed).
        anchor_idx = field.AddExtraPoint(anchor.x, anchor.y, anchor.z, fixed=True) - 1
        field.AddDistanceConstraint(anchor_idx, match[core_idx], 0, 0, 100.)
    field.Initialize()
    not_converged = field.Minimize(energyTol=1e-4, forceTol=1e-3)
    for _ in range(4):
        if not not_converged:
            break
        not_converged = field.Minimize(energyTol=1e-4, forceTol=1e-3)
    # realign
    mol.SetDoubleProp('EmbedRMS', AlignMol(mol, core, atomMap=algMap))
def execute(candidates_json, outfile_base, min_ph=None, max_ph=None, max_inputs=0, max_outputs=0, modulus=0):
    """Tether every candidate in ``candidates_json`` to each of its source hits and write SD files.

    The JSON must provide:

    - ``queries.molecules``: a list of ``{'id', 'originalMol'}`` entries
      (molblocks of the hits);
    - ``results``: a list of ``{'smiles', 'sourceMols'}`` entries, where
      ``sourceMols`` lists hit ids.

    For each result and each of its hits, the MCS between hit and candidate
    defines a core. The candidate (optionally charge-enumerated) is embedded
    onto that core with :func:`MultiConstrainedEmbed`. Every resulting conformer
    is written to the hit's current SD file (see :func:`get_writer`).

    :param candidates_json: path to the candidates JSON file
    :param outfile_base: base name for the output SD files
    :param min_ph: lower pH bound; charge states are enumerated only when both
        bounds are given (not None)
    :param max_ph: upper pH bound
    :param max_inputs: stop after this many results (0 = no limit)
    :param max_outputs: stop after this many results have been processed (0 = no limit)
    :param modulus: if > 0, only process results whose position is a multiple of it
    """
    with open(candidates_json, 'r') as f:
        candidates = json.load(f)
    queries = candidates['queries']['molecules']
    results = candidates['results']
    print(len(queries), len(results))

    # pH 0.0 is a legitimate bound, so test for None rather than truthiness.
    enumerate_ph = min_ph is not None and max_ph is not None

    num_mols = 0
    num_processed = 0
    num_outputs = 0
    num_errors = 0

    hits = {}
    for query in queries:
        hit_id = query['id']
        hit = Chem.MolFromMolBlock(query['originalMol'])
        if hit is None:
            print('Could not parse hit', hit_id)
            continue
        print("Examining hit", hit_id, Chem.MolToSmiles(hit))
        hits[hit_id] = hit

    try:
        for result in results:
            if 0 < modulus and num_mols % modulus != 0:
                num_mols += 1
                continue
            if 0 < max_inputs <= num_mols:
                break
            if 0 < max_outputs <= num_processed:
                break
            smiles = result['smiles']
            ids = []
            for hit_id in result['sourceMols']:
                hit = hits.get(hit_id)
                if hit is None:
                    num_errors += 1
                    print(' Error: ', smiles, 'unknown or unparsable hit', hit_id)
                    continue
                w = get_writer(hit_id, outfile_base)
                ids.append(hit_id)
                print('Processing', smiles)
                try:
                    mol = Chem.MolFromSmiles(smiles)
                    if mol is None:
                        raise ValueError('Could not parse SMILES')
                    if enumerate_ph:
                        enumerated_mols = enumerate_charges(mol, min_ph, max_ph)
                        print(' Enumerated', len(enumerated_mols), [Chem.MolToSmiles(x) for x in enumerated_mols])
                    else:
                        enumerated_mols = [mol]
                    # The core is the part of the hit shared (via the MCS) with the candidate.
                    mcs0 = rdFMCS.FindMCS([hit, mol], completeRingsOnly=True, matchValences=False,
                                          bondCompare=rdFMCS.BondCompare.CompareOrderExact)
                    replaced = AllChem.ReplaceSidechains(hit, mcs0.queryMol)
                    core = AllChem.DeleteSubstructs(replaced, Chem.MolFromSmiles('*'))
                    core.UpdatePropertyCache()
                    current_mol_count = molecule_counts.get(hit_id, 0)
                    for ligand in enumerated_mols:
                        molh = Chem.AddHs(ligand)
                        mol_match_tuple = MultiConstrainedEmbed(molh, core)
                        print(' ', len(mol_match_tuple), 'mols tethered to', hit.GetProp('_Name'))
                        for t_mol, match in mol_match_tuple:
                            t_mol.SetProp('_Name', smiles)
                            t_mol.SetProp('Hit', hit_id)
                            # SD consumers expect 1-based atom numbers.
                            tethers = ','.join(str(x + 1) for x in match)
                            t_mol.SetProp('TETHERED ATOMS', tethers)
                            print(' Tethers: ', tethers)
                            w.write(t_mol)
                            num_outputs += 1
                    current_mol_count += len(enumerated_mols)
                    molecule_counts[hit_id] = current_mol_count
                except Exception as e:
                    # Best effort: report and carry on with the next hit/candidate.
                    num_errors += 1
                    print(' Error: ', smiles, e)
                    traceback.print_exc()
            print(' processed mol', num_processed, 'Targets:', ','.join(ids))
            num_processed += 1
            num_mols += 1
    finally:
        # Always flush/close the SD files, even if the loop aborts.
        for writer in writers.values():
            writer.close()
    print('Totals - count:', num_processed, 'total:', num_outputs, 'errors:', num_errors)
def main():
    """
    Parse the command line, run :func:`execute` and report per-hit totals.

    Example usage:
    python scripts/transfs/prepare-tether.py -i targets/mpro/expts/transfs/expanded-22_hac0-10_rac0-3_hops2.json --max-inputs 5000 --chunk-size 100
    """
    global chunk_size
    # Suppress basic RDKit logging...
    RDLogger.logger().setLevel(RDLogger.ERROR)
    print('RDKit version:', rdBase.rdkitVersion)
    parser = argparse.ArgumentParser(description='Tether prep - prepare candidates for docking')
    # Required: without it execute() would fail trying to open None.
    parser.add_argument('-i', '--input', required=True, help='JSON containing the expanded candidates)')
    parser.add_argument('-o', '--outfile', default='tethered', help='Name for results (will generate something like Mpro-x0072_tethered_000.sdf)')
    parser.add_argument('--min-ph', type=float, help='The min pH to consider')
    parser.add_argument('--max-ph', type=float, help='The max pH to consider')
    parser.add_argument('-c', '--chunk-size', type=int, default=200, help='Chunk size for files')
    parser.add_argument('--max-inputs', type=int, default=0, help='Max number of molecules to process')
    parser.add_argument('--max-outputs', type=int, default=0, help='Max number of records to output')
    parser.add_argument('--modulus', type=int, default=0, help='Process only mols with this modulus')
    args = parser.parse_args()
    log("Tether prep args: ", args)
    if (args.min_ph is None) != (args.max_ph is None):
        log('WARNING: --min-ph and --max-ph must both be given; charge enumeration is disabled')

    chunk_size = args.chunk_size

    # Dimorphite needs to use argparse with its own arguments, not messed up with ours.
    # Keep only the program name while it runs so it is unaware of our parameters,
    # and restore the original argv afterwards.
    orig_sys_argv = sys.argv[:]
    sys.argv = sys.argv[:1]
    try:
        execute(args.input, args.outfile, min_ph=args.min_ph, max_ph=args.max_ph,
                max_inputs=args.max_inputs, max_outputs=args.max_outputs, modulus=args.modulus)
    finally:
        sys.argv = orig_sys_argv

    print('Swap successes:', num_swap_success)
    total = 0
    for hit_id, count in molecule_counts.items():
        print(hit_id, 'generated', count)
        total += count
    print("Total", total)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment