Skip to content

Instantly share code, notes, and snippets.

@maxentile
Created May 29, 2020 18:58
Show Gist options
  • Save maxentile/5c99e8d22f7219c1af26029e6cb874e5 to your computer and use it in GitHub Desktop.
Save maxentile/5c99e8d22f7219c1af26029e6cb874e5 to your computer and use it in GitHub Desktop.
[wip] interpreting a restricted smirks subset
"""Parse a subset of SMIRKS that covers "self-contained" patterns,
like [<atom>:1] <bond> [<atom>:2],
but not patterns like [<atom>:1] <bond> [<atom>:2]=[<atom>]
or [<atom>:1] <bond> [<atom>$(*~[#6]):2].

Later, it should be straightforward to extend this to handle recursive SMARTS
(like "$(*~[#6])"), but less straightforward to handle extended chemical
environments.
"""
# openff imports
import openforcefield
from openforcefield.typing.engines.smirnoff import ForceField
# Load the force field named 'openff-1.0.0.offxml'. This is a module-level
# statement, so it runs on import; the __main__ section serialises the result
# with `forcefield.to_string()`.
forcefield = ForceField('openff-1.0.0.offxml')
# Print the toolkit's version info, read from its private `_version` module.
print(openforcefield._version.get_versions())
import re
def extract_all_smirks(forcefield_string):
    """Return the value of every ``smirks="..."`` attribute in the text.

    Parameters
    ----------
    forcefield_string : str
        Text to scan (the script passes serialised force field XML).

    Returns
    -------
    list of str
        Attribute values in order of appearance, without the surrounding
        double quotes.
    """
    # Capture only up to the closing quote of each attribute.  A greedy
    # ``.*"`` would run to the last quote on the line and hide any further
    # smirks attributes sharing that line.
    return re.findall(r'smirks="([^"]*)"', forcefield_string)
def extract_all_bond_smirks(forcefield_string):
    """Collect the smirks patterns found between ``<Bonds`` and ``</Bonds``."""
    # Both slice bounds come straight from str.find on the tag text.
    section = slice(
        forcefield_string.find('<Bonds'),
        forcefield_string.find('</Bonds'),
    )
    return extract_all_smirks(forcefield_string[section])
def extract_all_angle_smirks(forcefield_string):
    """Collect the smirks patterns found between ``<Angles`` and ``</Angles``."""
    # Both slice bounds come straight from str.find on the tag text.
    section = slice(
        forcefield_string.find('<Angles'),
        forcefield_string.find('</Angles'),
    )
    return extract_all_smirks(forcefield_string[section])
def extract_all_torsion_smirks(forcefield_string):
    """Collect the smirks patterns found between ``<ProperTorsions`` and ``</ProperTorsions``."""
    # Both slice bounds come straight from str.find on the tag text.
    section = slice(
        forcefield_string.find('<ProperTorsions'),
        forcefield_string.find('</ProperTorsions'),
    )
    return extract_all_smirks(forcefield_string[section])
from lark import Lark
# Parse-tree node names (rule names / aliases used in the atom grammar),
# grouped by how `interpret` evaluates them.
numerical_primitives = [
    "atomic_number", "total_connectivity",
    "ring_connectivity", "total_h_count", "positive_charge",
    "negative_charge", "ring_size"
]
boolean_primitives = ["aromatic", "wildcard"]
binary_operators = ["and", "or"]
unary_operators = ["not", "primitive_or_negation", ]
many_operators = ["implicit_and", "atom", "atomic_primitive_expr"]

from functools import reduce


def interpret(tree, atom: dict):
    """Evaluate a parsed atom expression against one atom's properties.

    Loosely modeled on
    https://www.cs.cornell.edu/~asampson/blog/minisynth.html

    Parameters
    ----------
    tree
        Parse-tree node exposing ``data`` (the node name) and ``children``
        (sub-nodes, or a number token for numerical primitives).
    atom : dict
        Properties of the atom being tested.  Numerical and boolean
        primitives are looked up under their own node name, except the two
        charge primitives, which are compared against ``"formal_charge"``.

    Returns
    -------
    The truth value of the expression, or ``None`` when ``tree.data`` is not
    one of the node names listed in the tables above (the ``__main__``
    section checks for results that are neither ``True`` nor ``False``).

    Raises
    ------
    AssertionError
        If a numerical primitive or unary operator does not have exactly one
        child (a ``ring_size`` node may also have none).
    """
    op = tree.data
    children = tree.children

    if op in numerical_primitives:
        # Ring size may come without a number; then any positive ring size
        # recorded for the atom matches.
        if op == "ring_size" and len(children) == 0:
            return atom[op] > 0
        assert len(children) == 1
        val = int(children[0])
        if op == "positive_charge":
            return atom["formal_charge"] == val
        if op == "negative_charge":
            return atom["formal_charge"] == -val
        return atom[op] == val

    if op in boolean_primitives:
        return atom[op]

    if op in binary_operators:
        lhs = interpret(children[0], atom)
        rhs = interpret(children[1], atom)
        if op == "and":
            return lhs and rhs
        return lhs or rhs

    if op in unary_operators:
        assert len(children) == 1
        result = interpret(children[0], atom)
        # "primitive_or_negation" is a pass-through wrapper; only "not" flips.
        return (not result) if op == "not" else result

    if op in many_operators:
        if len(children) == 1:
            return interpret(children[0], atom)
        values = [interpret(c, atom) for c in children]
        # Seeding the fold with True leaves non-empty results unchanged and
        # makes a childless node evaluate to True instead of raising
        # TypeError from reduce() on an empty sequence.
        return reduce(lambda a, b: a and b, values, True)

    # Unrecognised node name: no verdict.
    return None
# Lark rules for one atom expression; concatenated into all three environment
# grammars below.
#   * atom1..atom4: an `atom` followed by the literal ":1" .. ":4".
#   * atomic_primitive_expr: primitives (optionally negated with "!") combined
#     with "&" or ";" (both aliased to `and`), with "," (aliased to `or`), or
#     by juxtaposing 2-3 of them (`implicit_and`).
#   * atomic_primitive: each alternative is aliased to one of the node names
#     listed in `numerical_primitives` / `boolean_primitives`.
#   * connective: defined here but referenced by the `bond` rule in
#     `bonds_grammar`.
# NUMBER is not imported in this string; the `%import common.NUMBER` line lives
# at the end of `bonds_grammar`, so this fragment is not a grammar on its own.
# NOTE(review): most rule bodies start with "|" right after the colon, which
# looks like it introduces an empty alternative -- confirm this is intended.
atoms_grammar = """
// atom
atom1 : atom ":1"
atom2 : atom ":2"
atom3 : atom ":3"
atom4 : atom ":4"
primitive_or_negation :
| atomic_primitive
| "!" atomic_primitive -> not
implicit_and :
| primitive_or_negation ~ 2..3
atomic_primitive_expr :
| primitive_or_negation
| implicit_and
| atomic_primitive_expr "&" atomic_primitive_expr -> and
| atomic_primitive_expr ";" atomic_primitive_expr -> and
| atomic_primitive_expr "," atomic_primitive_expr -> or
atom : atomic_primitive_expr
// atomic primitives
atomic_primitive :
| "*" -> wildcard
| "#" NUMBER -> atomic_number
| "a" -> aromatic
| "X" NUMBER -> total_connectivity
| "x" NUMBER -> ring_connectivity
| "H" NUMBER -> total_h_count
| "+" NUMBER -> positive_charge
| "-" NUMBER -> negative_charge
| "r" NUMBER* -> ring_size
connective :
| "&" -> and
| ";" -> and
| "," -> or
"""
# Lark rules for the bond expression between two bracketed atoms; concatenated
# into all three environment grammars below.
#   * bond12 / bond23 / bond34: position-specific wrappers around `bond`.
#   * bond: 1-3 juxtaposed bond_primitive_expr, two of them joined by a
#     `connective` (that rule is defined in `atoms_grammar`), or a "!"-negated
#     bond_primitive_expr.
#   * bond_primitive: each alternative is aliased to a named bond type.
# The trailing `%import common.NUMBER` supplies the NUMBER terminal used by the
# atom rules.
# NOTE(review): `bond_primitive_expr` and `bond_primitive` start with "|" right
# after the colon, which looks like it introduces an empty alternative --
# confirm this is intended.
bonds_grammar = """
// labeled bonds
bond12 : bond
bond23 : bond
bond34 : bond
// bond
bond : bond_primitive_expr ~ 1..3
| bond_primitive_expr connective bond_primitive_expr
| "!" bond_primitive_expr -> not
bond_primitive_expr :
| bond_primitive
| "!" bond_primitive_expr -> not
// bond primitives
bond_primitive :
| "~" -> any_bond
| "-" -> single_bond
| "=" -> double_bond
| "#" -> triple_bond
| ":" -> aromatic_bond
| "@" -> any_ring_bond
%import common.NUMBER
"""
# Start rules for the three supported pattern shapes.  Each is a chain of
# bracketed, numbered atoms (atom1, atom2, ...) separated by the matching
# position-specific bond rule (bond12, bond23, bond34).

# Two atoms, one bond.
bond_environment_template = """
bond_environment : "[" atom1 "]" bond12 "[" atom2 "]"
"""
# Three atoms, two bonds.
angle_environment_template = """
angle_environment : "[" atom1 "]" bond12 "[" atom2 "]" bond23 "[" atom3 "]"
"""
# Four atoms, three bonds.
torsion_environment_template = """
torsion_environment : "[" atom1 "]" bond12 "[" atom2 "]" bond23 "[" atom3 "]" bond34 "[" atom4 "]"
"""
# Each full grammar is its start-rule template followed by the shared atom
# rules and then the shared bond rules.
bond_environment_grammar = ''.join(
    (bond_environment_template, atoms_grammar, bonds_grammar)
)
angle_environment_grammar = ''.join(
    (angle_environment_template, atoms_grammar, bonds_grammar)
)
torsion_environment_grammar = ''.join(
    (torsion_environment_template, atoms_grammar, bonds_grammar)
)

# One parser per pattern shape, each started at its own top-level rule.
bond_parser = Lark(bond_environment_grammar, start='bond_environment')
angle_parser = Lark(angle_environment_grammar, start='angle_environment')
torsion_parser = Lark(torsion_environment_grammar, start='torsion_environment')
def try_to_parse(smirks_list, parser):
    """Run ``parser.parse`` on every pattern and sort the patterns by outcome.

    Returns a ``(successes, failures, errors)`` tuple of lists: the patterns
    that parsed, the patterns that raised, and the exceptions raised (in the
    same order as the failures).
    """
    parsed, rejected, raised = [], [], []
    for pattern in smirks_list:
        try:
            parser.parse(pattern)
        except Exception as err:
            # Keep the exception so the caller can report why parsing failed.
            rejected.append(pattern)
            raised.append(err)
        else:
            parsed.append(pattern)
    return parsed, rejected, raised
def format_parse_error(e):
    """Return the third and fourth lines of ``str(e)``, joined by a newline."""
    text = str(e)
    # Four splits are enough to isolate lines 0-3; anything after stays in a
    # fifth piece that the slice below ignores.
    selected = text.split('\n', 4)[2:4]
    return '\n'.join(selected)
def get_atom_environments(tree):
    """List every subtree yielded by ``tree.iter_subtrees()`` whose ``data`` is ``'atom'``."""
    return [node for node in tree.iter_subtrees() if node.data == 'atom']
if __name__ == '__main__':
    # Pull the smirks patterns of each parameter section out of the
    # serialised force field.
    ff_string = forcefield.to_string()
    bond_smirks = extract_all_bond_smirks(ff_string)
    angle_smirks = extract_all_angle_smirks(ff_string)
    torsion_smirks = extract_all_torsion_smirks(ff_string)

    # Dump every pattern, one per line.  The joined text is a single string,
    # so write() is the right call (writelines() would iterate it character
    # by character).
    with open('smirks.txt', 'w') as f:
        f.write('\n'.join(bond_smirks + angle_smirks + torsion_smirks))

    blocks = ['bonds', 'angles', 'torsions']
    smirks_lists = dict(zip(blocks, [bond_smirks, angle_smirks, torsion_smirks]))
    grammars = dict(zip(blocks, [bond_environment_grammar, angle_environment_grammar, torsion_environment_grammar]))
    parsers = dict(zip(blocks, [bond_parser, angle_parser, torsion_parser]))

    # Try every pattern against the parser for its section.
    successes, failures, errors = dict(), dict(), dict()
    for block in blocks:
        successes[block], failures[block], errors[block] = try_to_parse(smirks_lists[block], parsers[block])

    # Coverage summary.
    for block in blocks:
        print(f'{block}: {len(successes[block])} parsed, {len(failures[block])} not parsed')

    # Show an excerpt of each parse error, grouped by section.
    # NOTE(review): the source's indentation was lost; the closing separator
    # is assumed to be printed once, after the loop -- confirm.
    for block in blocks:
        print('-' * 30)
        print(f'{block} patterns not covered')
        for e in errors[block]:
            print(format_parse_error(e))
    print('-' * 30)

    # Gather the `atom` subtrees of every pattern that parsed, per section.
    from collections import defaultdict
    atom_environments = defaultdict(list)
    for block in blocks:
        for smirks in successes[block]:
            parse_tree = parsers[block].parse(smirks)
            atom_environments[block].extend(get_atom_environments(parse_tree))
        n_environments, n_unique = len(atom_environments[block]), len(set(atom_environments[block]))
        print(f'{block}:\n\t# of atom enviroments: {n_environments}\n\t# syntactically unique: {n_unique}')

    # in total
    all_atom_environments = []
    for block in blocks:
        all_atom_environments.extend(atom_environments[block])
    n_environments, n_unique = len(all_atom_environments), len(set(all_atom_environments))
    print(f'altogether:\n\t# of atom enviroments: {n_environments}\n\t# syntactically unique: {n_unique}')

    # TODO: print the actual pattern, instead of the parse tree
    # for block in blocks:
    #     reconstructor = Reconstructor(atoms_grammar)

    # Next, check for logical equivalence
    # TODO: think about how to do this... CNF?
    # or, getting ahead of ourselves: this is just an optimization; it's fine
    # to do some repeated work...

    # Next, extract just the atomic primitives involved
    # mol = Molecule.from_smiles("CCC")
    # atoms = list(mol.atoms)
    # atom = atoms[0]
    unique = set(all_atom_environments)
    for t in unique:
        print(t.pretty())

    # Hand-written property table standing in for a real atom.
    atom = dict(
        wildcard=True,
        total_connectivity=2,
        ring_connectivity=2,
        ring_size=5,
        atomic_number=7,
        formal_charge=1,
        total_h_count=1,
        aromatic=False,
    )

    # Evaluate every unique atom expression against that atom and report any
    # result that is not a plain boolean.
    for t in unique:
        # print(t.pretty())
        result = interpret(t, atom)
        if result not in [True, False]:
            print('woops!')
            print(result)
            print(t.pretty())

    # TODO: inspect bond expressions
    # TODO: encode boolean expressions using sympy
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment