Created
May 29, 2020 18:58
-
-
Save maxentile/5c99e8d22f7219c1af26029e6cb874e5 to your computer and use it in GitHub Desktop.
[wip] interpreting a restricted smirks subset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Parse a subset of SMIRKS that covers "self-contained" patterns

like     [<atom>:1] <bond> [<atom>:2]
but not  [<atom>:1] <bond> [<atom>:2]=[<atom>]
or       [<atom>:1] <bond> [<atom>$(*~[#6]):2]

Later, it should be straightforward to extend this to handle recursive SMARTS (like "$(*~[#6])"),
but less straightforward to handle extended chemical environments.
"""
# openff imports | |
import openforcefield | |
from openforcefield.typing.engines.smirnoff import ForceField | |
# Load the 'openff-1.0.0.offxml' force field when the module is imported;
# the __main__ section serializes it to extract the SMIRKS patterns.
forcefield = ForceField('openff-1.0.0.offxml')
# Print the version details of the installed openforcefield package
# (read through its private `_version` module).
print(openforcefield._version.get_versions())
import re | |
def extract_all_smirks(forcefield_string):
    """Return the value of every ``smirks="..."`` attribute in *forcefield_string*.

    Values are returned in order of appearance, without the surrounding
    double quotes. The pattern captures up to the closing quote, so several
    attributes on the same line are all reported and a value that contains
    whitespace is returned whole.

    Returns an empty list when no attribute is found.
    """
    # NOTE(review): values are returned exactly as written in the text; if the
    # serializer escapes characters (e.g. "&" as "&amp;"), they are not
    # unescaped here -- confirm whether that matters for downstream parsing.
    return re.findall(r'smirks="([^"]*)"', forcefield_string)
def extract_all_bond_smirks(forcefield_string):
    """Return the SMIRKS patterns found inside the ``<Bonds ...>`` section.

    The section is the text from the first ``<Bonds`` up to the next
    ``</Bonds``. An empty list is returned when either tag is missing, so a
    malformed or partial string never yields patterns from other sections.
    """
    start_ind = forcefield_string.find('<Bonds')
    if start_ind == -1:
        return []
    # look for the closing tag only after the opening one
    end_ind = forcefield_string.find('</Bonds', start_ind)
    if end_ind == -1:
        return []
    return extract_all_smirks(forcefield_string[start_ind:end_ind])
def extract_all_angle_smirks(forcefield_string):
    """Return the SMIRKS patterns found inside the ``<Angles ...>`` section.

    The section is the text from the first ``<Angles`` up to the next
    ``</Angles``. An empty list is returned when either tag is missing, so a
    malformed or partial string never yields patterns from other sections.
    """
    start_ind = forcefield_string.find('<Angles')
    if start_ind == -1:
        return []
    # look for the closing tag only after the opening one
    end_ind = forcefield_string.find('</Angles', start_ind)
    if end_ind == -1:
        return []
    return extract_all_smirks(forcefield_string[start_ind:end_ind])
def extract_all_torsion_smirks(forcefield_string):
    """Return the SMIRKS patterns found inside the ``<ProperTorsions ...>`` section.

    The section is the text from the first ``<ProperTorsions`` up to the next
    ``</ProperTorsions``. An empty list is returned when either tag is
    missing, so a malformed or partial string never yields patterns from
    other sections.
    """
    start_ind = forcefield_string.find('<ProperTorsions')
    if start_ind == -1:
        return []
    # look for the closing tag only after the opening one
    end_ind = forcefield_string.find('</ProperTorsions', start_ind)
    if end_ind == -1:
        return []
    return extract_all_smirks(forcefield_string[start_ind:end_ind])
from lark import Lark | |
# Parse-tree node names, grouped by how `interpret` evaluates them.

# Primitives carrying an integer that is compared against the atom record
# (the two charge primitives compare against "formal_charge"; "ring_size"
# may also appear without a number).
numerical_primitives = [
    "atomic_number",
    "total_connectivity",
    "ring_connectivity",
    "total_h_count",
    "positive_charge",
    "negative_charge",
    "ring_size",
]

# Primitives whose value is read straight from the atom record.
boolean_primitives = [
    "aromatic",
    "wildcard",
]

# Nodes combining exactly two sub-expressions.
binary_operators = [
    "and",
    "or",
]

# Nodes wrapping exactly one sub-expression.
unary_operators = [
    "not",
    "primitive_or_negation",
]

# Nodes with one or more sub-expressions that are and-ed together.
many_operators = [
    "implicit_and",
    "atom",
    "atomic_primitive_expr",
]
from functools import reduce | |
def interpret(tree, atom: dict):
    """Evaluate a parsed atom expression against one atom description.

    loosely modeled on
    https://www.cs.cornell.edu/~asampson/blog/minisynth.html

    `tree` is a parse-tree node: `tree.data` is the node name and
    `tree.children` holds its sub-nodes (for a numerical primitive, a single
    integer-like token). `atom` maps primitive names, plus "formal_charge",
    to the values of the atom being tested.

    Returns the result of the comparison / boolean combination the node
    stands for. A node whose name is in none of the module-level name lists
    yields None.
    """
    node = tree.data
    kids = tree.children

    if node in numerical_primitives:
        # ring size can optionally not have a number: then any ring_size > 0 matches
        if node == "ring_size" and len(kids) == 0:
            return atom[node] > 0
        assert len(kids) == 1
        expected = int(kids[0])
        # charges are stored as one signed "formal_charge" entry
        if node == "positive_charge":
            return atom["formal_charge"] == expected
        if node == "negative_charge":
            return atom["formal_charge"] == -expected
        return atom[node] == expected

    if node in boolean_primitives:
        return atom[node]

    if node in binary_operators:
        # both operands are evaluated before they are combined
        left = interpret(kids[0], atom)
        right = interpret(kids[1], atom)
        if node == "and":
            return left and right
        if node == "or":
            return left or right
        return None

    if node in unary_operators:
        assert len(kids) == 1
        inner = interpret(kids[0], atom)
        # "primitive_or_negation" is a pure pass-through wrapper
        return (not inner) if node == "not" else inner

    if node in many_operators:
        if len(kids) == 1:
            return interpret(kids[0], atom)
        results = [interpret(kid, atom) for kid in kids]
        if node in ("implicit_and", "atom", "atomic_primitive_expr"):
            return reduce(lambda acc, nxt: acc and nxt, results)
        return None

    # unrecognized node name
    return None
# Lark rules for the labeled atom expressions (`atom1` ... `atom4`).
# Not a standalone grammar: NUMBER is only imported in `bonds_grammar`, and the
# `connective` rule defined here is used by the `bond` rule there, so the two
# fragments are concatenated into the *_environment_grammar strings.
# The aliases after `->` are the node names that `interpret` dispatches on.
# NOTE(review): a rule body that begins with `|` appears to give the rule an
# additional empty alternative in Lark -- confirm that this is intended.
atoms_grammar = """
// atom
atom1 : atom ":1"
atom2 : atom ":2"
atom3 : atom ":3"
atom4 : atom ":4"

primitive_or_negation :
    | atomic_primitive
    | "!" atomic_primitive -> not

implicit_and :
    | primitive_or_negation ~ 2..3

atomic_primitive_expr :
    | primitive_or_negation
    | implicit_and
    | atomic_primitive_expr "&" atomic_primitive_expr -> and
    | atomic_primitive_expr ";" atomic_primitive_expr -> and
    | atomic_primitive_expr "," atomic_primitive_expr -> or

atom : atomic_primitive_expr

// atomic primitives
atomic_primitive :
    | "*" -> wildcard
    | "#" NUMBER -> atomic_number
    | "a" -> aromatic
    | "X" NUMBER -> total_connectivity
    | "x" NUMBER -> ring_connectivity
    | "H" NUMBER -> total_h_count
    | "+" NUMBER -> positive_charge
    | "-" NUMBER -> negative_charge
    | "r" NUMBER* -> ring_size

connective :
    | "&" -> and
    | ";" -> and
    | "," -> or
"""
# Lark rules for the labeled bond expressions (`bond12`, `bond23`, `bond34`).
# Not a standalone grammar: the `connective` rule used by `bond` is defined in
# `atoms_grammar`. This fragment also carries the `%import common.NUMBER`
# that the atom rules rely on.
# NOTE(review): a rule body that begins with `|` appears to give the rule an
# additional empty alternative in Lark -- confirm that this is intended.
bonds_grammar = """
// labeled bonds
bond12 : bond
bond23 : bond
bond34 : bond

// bond
bond : bond_primitive_expr ~ 1..3
    | bond_primitive_expr connective bond_primitive_expr
    | "!" bond_primitive_expr -> not

bond_primitive_expr :
    | bond_primitive
    | "!" bond_primitive_expr -> not

// bond primitives
bond_primitive :
    | "~" -> any_bond
    | "-" -> single_bond
    | "=" -> double_bond
    | "#" -> triple_bond
    | ":" -> aromatic_bond
    | "@" -> any_ring_bond

%import common.NUMBER
"""
# Start rules for the three environment types: 2, 3 or 4 bracketed, labeled
# atoms joined by labeled bonds. Each template is prepended to the shared
# atom and bond rules to form a complete grammar.
bond_environment_template = """
bond_environment : "[" atom1 "]" bond12 "[" atom2 "]"
"""
angle_environment_template = """
angle_environment : "[" atom1 "]" bond12 "[" atom2 "]" bond23 "[" atom3 "]"
"""
torsion_environment_template = """
torsion_environment : "[" atom1 "]" bond12 "[" atom2 "]" bond23 "[" atom3 "]" bond34 "[" atom4 "]"
"""
# Each complete grammar is its start-rule template followed by the shared
# atom rules and bond rules.
(
    bond_environment_grammar,
    angle_environment_grammar,
    torsion_environment_grammar,
) = (
    template + atoms_grammar + bonds_grammar
    for template in (
        bond_environment_template,
        angle_environment_template,
        torsion_environment_template,
    )
)

# One parser per environment type, started at the rule its template defines.
bond_parser = Lark(bond_environment_grammar, start='bond_environment')
angle_parser = Lark(angle_environment_grammar, start='angle_environment')
torsion_parser = Lark(torsion_environment_grammar, start='torsion_environment')
def try_to_parse(smirks_list, parser):
    """Run `parser.parse` on every pattern and sort the patterns by outcome.

    Returns a `(successes, failures, errors)` tuple of lists: the patterns
    that parsed, the patterns that raised, and the exceptions raised
    (`errors[i]` belongs to `failures[i]`). Input order is preserved.
    """
    parsed, rejected, raised = [], [], []
    for pattern in smirks_list:
        try:
            parser.parse(pattern)
        except Exception as err:
            # any exception counts as "not covered"; keep it for reporting
            rejected.append(pattern)
            raised.append(err)
        else:
            parsed.append(pattern)
    return parsed, rejected, raised
def format_parse_error(e):
    """Return the third and fourth lines of the error's message, newline-joined.

    Messages with fewer lines give a shorter (possibly empty) string.
    """
    message_lines = str(e).split('\n')
    excerpt = message_lines[2:4]
    return '\n'.join(excerpt)
def get_atom_environments(tree):
    """Return every subtree of `tree` named 'atom', in `iter_subtrees()` order."""
    return [subtree for subtree in tree.iter_subtrees() if subtree.data == 'atom']
if __name__ == '__main__':
    from collections import defaultdict

    # --- collect the SMIRKS patterns of each parameter section ---------------
    ff_string = forcefield.to_string()

    bond_smirks = extract_all_bond_smirks(ff_string)
    angle_smirks = extract_all_angle_smirks(ff_string)
    torsion_smirks = extract_all_torsion_smirks(ff_string)

    # dump all patterns, one per line; the text is a single string, so write()
    # is the right call (writelines() would iterate it character by character)
    with open('smirks.txt', 'w') as f:
        f.write('\n'.join(bond_smirks + angle_smirks + torsion_smirks))

    # --- see how many patterns each grammar covers ---------------------------
    blocks = ['bonds', 'angles', 'torsions']
    smirks_lists = dict(zip(blocks, [bond_smirks, angle_smirks, torsion_smirks]))
    # `grammars` is not read again below
    grammars = dict(zip(blocks, [bond_environment_grammar, angle_environment_grammar, torsion_environment_grammar]))
    parsers = dict(zip(blocks, [bond_parser, angle_parser, torsion_parser]))

    successes, failures, errors = dict(), dict(), dict()
    for block in blocks:
        successes[block], failures[block], errors[block] = try_to_parse(smirks_lists[block], parsers[block])

    for block in blocks:
        print(f'{block}: {len(successes[block])} parsed, {len(failures[block])} not parsed')

    # NOTE(review): indentation was lost in this copy of the file; the position
    # of the closing separator (after all blocks) is reconstructed -- confirm.
    for block in blocks:
        print('-' * 30)
        print(f'{block} patterns not covered')
        for e in errors[block]:
            print(format_parse_error(e))
    print('-' * 30)

    # --- gather the atom sub-expressions of every pattern that parsed --------
    atom_environments = defaultdict(list)
    for block in blocks:
        for smirks in successes[block]:
            parse_tree = parsers[block].parse(smirks)
            atom_environments[block].extend(get_atom_environments(parse_tree))

        n_environments, n_unique = len(atom_environments[block]), len(set(atom_environments[block]))
        print(f'{block}:\n\t# of atom enviroments: {n_environments}\n\t# syntactically unique: {n_unique}')

    # in total
    all_atom_environments = []
    for block in blocks:
        all_atom_environments.extend(atom_environments[block])
    n_environments, n_unique = len(all_atom_environments), len(set(all_atom_environments))
    print(f'altogether:\n\t# of atom enviroments: {n_environments}\n\t# syntactically unique: {n_unique}')

    # TODO: print the actual pattern, instead of the parse tree
    # for block in blocks:
    #     reconstructor = Reconstructor(atoms_grammar)

    # Next, check for logical equivalence
    # TODO: think about how to do this... CNF?
    #   or getting ahead of self, this is just an optimization: it's fine to do some repeated work...

    # Next, extract just the atomic primitives involved
    # mol = Molecule.from_smiles("CCC")
    # atoms = list(mol.atoms)
    # atom = atoms[0]

    unique = set(all_atom_environments)
    for t in unique:
        print(t.pretty())

    # --- evaluate every unique atom expression on a hand-written atom --------
    atom = dict(
        wildcard=True,
        total_connectivity=2,
        ring_connectivity=2,
        ring_size=5,
        atomic_number=7,
        formal_charge=1,
        total_h_count=1,
        aromatic=False,
    )

    # NOTE(review): nesting reconstructed -- the details are assumed to be
    # printed only when `interpret` returns something other than True/False.
    for t in unique:
        # print(t.pretty())
        result = interpret(t, atom)
        if result not in [True, False]:
            print('woops!')
            print(result)
            print(t.pretty())

    # TODO: inspect bond expressions
    # TODO: encode boolean expressions using sympy
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment