Skip to content

Instantly share code, notes, and snippets.

@maclandrol
Created June 4, 2021 21:13
Show Gist options
  • Save maclandrol/24c0372a54e08e1d4e31528fd4d9af79 to your computer and use it in GitHub Desktop.
Save maclandrol/24c0372a54e08e1d4e31528fd4d9af79 to your computer and use it in GitHub Desktop.
RDKit multi-hit substructure highlighting
from collections import defaultdict as ddict
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
from IPython.display import Image
import datamol as dm
import seaborn as sns
# Seaborn's "pastel" palette, materialized as a plain list. get_color_map
# cycles through it so that each substructure hit gets its own color.
COLORS = [*sns.palettes.color_palette("pastel")]
def get_match_highlights(mol, pattern, *, uniquify=True, max_matches=1000):
    """Return the atom and bond indices of every hit of *pattern* in *mol*.

    Each atom match returned by ``GetSubstructMatches`` is a tuple whose k-th
    entry is the index of the *mol* atom matched by pattern atom k. That
    ordering is what lets every pattern bond be translated into the matching
    *mol* bond for each hit.

    Args:
        mol: molecule to search.
        pattern: query molecule (e.g. from ``Chem.MolFromSmarts``).
        uniquify: forwarded to ``GetSubstructMatches`` as ``uniquify``.
        max_matches: forwarded to ``GetSubstructMatches`` as ``maxMatches``
            (1000 is RDKit's documented default).

    Returns:
        ``(atom_matches, bond_matches)``: two lists of equal length. Entry i
        of ``atom_matches`` is the atom-index tuple of hit i. Entry i of
        ``bond_matches`` is the list of *mol* bond indices of hit i, in the
        order of ``pattern.GetBonds()``.
    """
    atom_matches = list(
        mol.GetSubstructMatches(pattern, uniquify=uniquify, maxMatches=max_matches)
    )
    # The pattern's bond endpoints are the same for every hit: read them once.
    pattern_bonds = [
        (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()) for bond in pattern.GetBonds()
    ]
    # Map each pattern bond through the hit's atom mapping to the mol bond.
    bond_matches = [
        [
            mol.GetBondBetweenAtoms(match[begin], match[end]).GetIdx()
            for begin, end in pattern_bonds
        ]
        for match in atom_matches
    ]
    return atom_matches, bond_matches
def get_color_map(atom_matches, bond_matches, colors=None):
    """Assign one highlight color per hit and collect the indices to highlight.

    Hit i (the i-th entry of *atom_matches* / *bond_matches*) gets
    ``colors[i % len(colors)]``, so the palette is cycled when there are more
    hits than colors. When hits overlap, a shared atom/bond takes the color of
    the last hit containing it, and its index appears once per hit in the
    returned lists.

    Args:
        atom_matches: per-hit iterables of atom indices.
        bond_matches: per-hit iterables of bond indices, aligned with
            *atom_matches*; iteration stops at the shorter of the two.
        colors: sequence of colors to cycle through. Defaults to the
            module-level ``COLORS`` palette.

    Returns:
        ``(atom_colors, bond_colors, all_atoms, all_bonds)``: index -> color
        maps for atoms and bonds, plus flat lists of every highlighted atom
        and bond index in hit order.

    Raises:
        ValueError: if *colors* is empty.
    """
    palette = COLORS if colors is None else list(colors)
    if not palette:
        raise ValueError("colors must contain at least one color")
    atom_colors, bond_colors = {}, {}
    all_atoms, all_bonds = [], []
    for i, (ats, bds) in enumerate(zip(atom_matches, bond_matches)):
        color = palette[i % len(palette)]
        for at in ats:
            atom_colors[at] = color
            all_atoms.append(at)
        for bd in bds:
            bond_colors[bd] = color
            all_bonds.append(bd)
    return atom_colors, bond_colors, all_atoms, all_bonds
def draw_hits(mol, pattern, mol_size=(500, 500)):
    """Draw *mol* with every hit of *pattern* highlighted in its own color.

    Atom indices and stereo annotations are shown, and atoms are drawn with
    the black-and-white atom palette so the highlight colors stand out.

    Args:
        mol: molecule to draw.
        pattern: substructure query whose hits are highlighted.
        mol_size: ``(width, height)`` passed to ``MolDraw2DCairo``.

    Returns:
        An ``IPython.display.Image`` wrapping the drawer's output bytes.
    """
    atoms, bonds = get_match_highlights(mol, pattern)
    atom_colors, bond_colors, all_atoms, all_bonds = get_color_map(atoms, bonds)

    drawer = rdMolDraw2D.MolDraw2DCairo(*mol_size)
    # drawOptions() returns the drawer's own options object (standard RDKit
    # usage), so settings made through this local take effect.
    opts = drawer.drawOptions()
    opts.addStereoAnnotation = True
    opts.addAtomIndices = True
    opts.useBWAtomPalette()
    rdMolDraw2D.PrepareAndDrawMolecule(
        drawer,
        mol,
        highlightAtoms=all_atoms,
        highlightAtomColors=atom_colors,
        highlightBonds=all_bonds,
        highlightBondColors=bond_colors,
        # highlightAtomRadii={atom_idx: radius} may also be passed to resize
        # the highlight of individual atoms.
    )
    # Finalize the drawing before extracting its bytes.
    drawer.FinishDrawing()
    return Image(drawer.GetDrawingText())
##### TEST #####
# `display` is only a builtin inside IPython/Jupyter; importing it explicitly
# avoids a NameError when this demo is run outside IPython.
from IPython.display import display  # noqa: E402

mol = dm.to_mol('CC(=O)OC1=CC=CC=C1C(=O)O')  # aspirin
# Some queries may need ring (SSSR) or aromaticity perception on the pattern
# before matching; that should not be necessary for this one.
# The pattern is chosen so that its hits do not overlap: get_color_map gives
# a shared atom/bond the color of the last hit that contains it.
pattern = Chem.MolFromSmarts("[C;H0](=O)[O]")
display(draw_hits(mol, pattern))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment