-
-
Save maclandrol/24c0372a54e08e1d4e31528fd4d9af79 to your computer and use it in GitHub Desktop.
RDKit Multi hits highlighting
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict as ddict

import datamol as dm
import seaborn as sns
from IPython.display import Image
from IPython.display import display
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
# Highlight palette: seaborn's "pastel" colors, materialised as a plain list
# so that match number i can be colored with COLORS[i % len(COLORS)].
_pastel_palette = sns.color_palette("pastel")
COLORS = list(_pastel_palette)
def get_match_highlights(mol, pattern):
    """Locate every hit of ``pattern`` in ``mol`` and the bonds each hit covers.

    The bond lookup relies on the substructure matcher returning, for each
    hit, the molecule atom indices in the same order as the atoms of the
    pattern: position ``k`` of a match is the molecule atom matched by
    pattern atom ``k``.

    Args:
        mol: Molecule to search (an RDKit ``Mol``).
        pattern: Query molecule, e.g. built with ``Chem.MolFromSmarts``.

    Returns:
        A tuple ``(atom_matches, bond_matches)``. ``atom_matches`` is a list
        with one tuple of molecule atom indices per hit. ``bond_matches`` is
        a list of the same length; entry ``i`` lists the molecule bond
        indices covered by hit ``i``, in pattern-bond order (empty when the
        pattern has no bonds).
    """
    atom_matches = list(mol.GetSubstructMatches(pattern, uniquify=True))

    # End points of each pattern bond, as pattern atom indices. Collected
    # once because they are the same for every hit.
    pattern_bond_ends = [
        (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
        for bond in pattern.GetBonds()
    ]

    # Translate every pattern bond into the corresponding molecule bond by
    # mapping its two end atoms through the hit.
    bond_matches = [
        [
            mol.GetBondBetweenAtoms(match[begin], match[end]).GetIdx()
            for begin, end in pattern_bond_ends
        ]
        for match in atom_matches
    ]
    return atom_matches, bond_matches
def get_color_map(atom_matches, bond_matches, colors=None):
    """Assign one highlight color per hit to its atoms and bonds.

    Hit ``i`` receives ``colors[i % len(colors)]``, so the palette is reused
    cyclically when there are more hits than colors. If an atom or bond
    belongs to several hits, the color of the last such hit wins and its
    index appears once per hit in the flat lists.

    Args:
        atom_matches: Iterable of atom-index sequences, one per hit.
        bond_matches: Iterable of bond-index sequences, one per hit, aligned
            with ``atom_matches``.
        colors: Optional non-empty sequence of colors to cycle through.
            Defaults to the module-level ``COLORS`` palette.

    Returns:
        A tuple ``(atom_colors, bond_colors, all_atoms, all_bonds)`` where
        the first two map an atom/bond index to its color and the last two
        are flat lists of every highlighted atom/bond index.

    Raises:
        ValueError: If the palette is empty.
    """
    palette = COLORS if colors is None else list(colors)
    if not palette:
        raise ValueError("colors must contain at least one color")

    atom_colors, bond_colors = {}, {}
    all_atoms, all_bonds = [], []
    for hit_index, (hit_atoms, hit_bonds) in enumerate(
        zip(atom_matches, bond_matches)
    ):
        color = palette[hit_index % len(palette)]
        atom_colors.update(dict.fromkeys(hit_atoms, color))
        all_atoms.extend(hit_atoms)
        bond_colors.update(dict.fromkeys(hit_bonds, color))
        all_bonds.extend(hit_bonds)
    return atom_colors, bond_colors, all_atoms, all_bonds
def draw_hits(mol, pattern, mol_size=(500, 500)):
    """Render ``mol`` as a PNG with every hit of ``pattern`` highlighted.

    Each hit is drawn in its own color (see ``get_color_map``). The drawing
    shows stereo annotations and atom indices and uses the black-and-white
    atom palette.

    Args:
        mol: Molecule to draw (an RDKit ``Mol``).
        pattern: Query molecule whose hits are highlighted.
        mol_size: ``(width, height)`` of the canvas in pixels.

    Returns:
        An ``IPython.display.Image`` wrapping the PNG data.
    """
    atoms, bonds = get_match_highlights(mol, pattern)
    atom_colors, bond_colors, all_atoms, all_bonds = get_color_map(atoms, bonds)

    d = rdMolDraw2D.MolDraw2DCairo(*mol_size)
    options = d.drawOptions()
    options.addStereoAnnotation = True
    options.addAtomIndices = True
    options.useBWAtomPalette()

    # A per-atom highlight radius can also be supplied here, e.g.
    # highlightAtomRadii={1: 0.5} for atom number 1.
    rdMolDraw2D.PrepareAndDrawMolecule(
        d,
        mol,
        highlightAtoms=all_atoms,
        highlightAtomColors=atom_colors,
        highlightBonds=all_bonds,
        highlightBondColors=bond_colors,
    )
    # Finalise the drawing before reading the image data back.
    d.FinishDrawing()
    return Image(d.GetDrawingText())
##### TEST #####
# Aspirin: contains two hits of the pattern below (ester and acid carbonyls).
mol = dm.to_mol('CC(=O)OC1=CC=CC=C1C(=O)O')
# In some cases you might need to run ring (SSSR) or aromaticity perception
# on the pattern first; it should not be necessary here.
# The pattern is chosen so that its hits do not overlap, which keeps each
# atom and bond in a single color.
pattern = Chem.MolFromSmarts("[C;H0](=O)[O]")
display(draw_hits(mol, pattern))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment