Created
February 21, 2020 19:36
-
-
Save maxentile/dcf438b0300b48534341a0cbf80d7711 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from numba import jit | |
# Shorthand for elementwise AND of boolean masks (the numpy ufunc itself).
from numpy import bitwise_and as AND
def off_mol_to_arrays(mol):
    """Flatten an openforcefield.topology.Molecule into two int32 arrays.

    Returns
    -------
    atom_array : (mol.n_atoms, 2) int32 array
        One ``(atomic_number, formal_charge)`` row per atom, in ``mol.atoms`` order.
    bond_array : (mol.n_bonds, 3) int32 array
        One ``(index1, index2, bond_order)`` row per bond, in ``mol.bonds`` order.

    NOTE(review): assumes ``atom.formal_charge`` is a plain integer (not a
    unit-bearing quantity) -- TODO confirm against the toolkit version in use.
    """
    atom_source = mol.atoms
    bond_source = mol.bonds
    atom_array = np.zeros((mol.n_atoms, 2), dtype=np.int32)
    bond_array = np.zeros((mol.n_bonds, 3), dtype=np.int32)

    for row, atom in enumerate(atom_source):
        atom_array[row, :] = (atom.atomic_number, atom.formal_charge)

    for row, bond in enumerate(bond_source):
        bond_array[row, :] = (bond.atom1_index, bond.atom2_index, bond.bond_order)

    return atom_array, bond_array
@jit(cache=True)  # cache=True to match the other numba kernels in this module
def compute_connectivity(atom_array, bond_array):
    """Count, for every atom, its bonded neighbors and its hydrogen neighbors.

    Parameters
    ----------
    atom_array : (n_atoms, 2) integer array
        ``(atomic_number, formal_charge)`` rows as produced by
        ``off_mol_to_arrays``; only column 0 (atomic number) is read.
    bond_array : (n_bonds, 3) integer array
        ``(index1, index2, bond_order)`` rows; only the two index columns are
        read, so each bond counts once regardless of its order.

    Returns
    -------
    X : (n_atoms,) int8 array
        ``X[i]`` is the number of bonds in ``bond_array`` that touch atom ``i``.
    H : (n_atoms,) int8 array
        ``H[i]`` is the number of bonded neighbors of atom ``i`` with atomic
        number 1.

    Notes
    -----
    Only bonds listed in ``bond_array`` are counted, so these match the SMARTS
    ``X``/``H`` primitives only when every hydrogen is an explicit atom.
    Counts are stored as int8.
    """
    X = np.zeros_like(atom_array[:, 0], dtype=np.int8)
    H = np.zeros_like(atom_array[:, 0], dtype=np.int8)
    for i in range(len(bond_array)):
        a, b = bond_array[i, 0], bond_array[i, 1]
        X[a] += 1
        X[b] += 1
        # A hydrogen on one end of the bond contributes to the H count of the other end.
        if atom_array[a, 0] == 1:
            H[b] += 1
        if atom_array[b, 0] == 1:
            H[a] += 1
    return X, H
@jit(cache=True)
def any_neighbors_w_attribute(binary_attribute_array, bond_array):
    """Flag atoms that have at least one bonded neighbor carrying an attribute.

    Parameters
    ----------
    binary_attribute_array : (n_atoms,) boolean array
        Per-atom property being tested on the neighbors.
    bond_array : (n_bonds, >=2) integer array
        Bond rows whose first two columns are atom indices (pass a filtered
        array, e.g. single bonds only, to restrict which bonds count).

    Returns
    -------
    presence : (n_atoms,) boolean array
        ``presence[i]`` is True when some atom ``j`` bonded to ``i`` has
        ``binary_attribute_array[j]`` True.
    """
    # np.bool_ rather than np.bool: the bare alias was removed in NumPy 1.24
    # (reinstated only in 2.0), while np.bool_ works everywhere, including numba.
    presence = np.zeros_like(binary_attribute_array, dtype=np.bool_)
    for i in range(len(bond_array)):
        a, b = bond_array[i, 0], bond_array[i, 1]
        if binary_attribute_array[a]:
            presence[b] = True
        if binary_attribute_array[b]:
            presence[a] = True
    return presence
@jit(cache=True)
def n_neighbors_w_attribute(binary_attribute_array, bond_array):
    """Count, per atom, the bonded neighbors that carry an attribute.

    Parameters
    ----------
    binary_attribute_array : (n_atoms,) boolean array
        Per-atom property being tested on the neighbors.
    bond_array : (n_bonds, >=2) integer array
        Bond rows whose first two columns are atom indices.

    Returns
    -------
    counts : (n_atoms,) int8 array
        ``counts[i]`` is the number of atoms ``j`` bonded to ``i`` with
        ``binary_attribute_array[j]`` True.
    """
    counts = np.zeros_like(binary_attribute_array, dtype=np.int8)
    n_bonds = bond_array.shape[0]
    for k in range(n_bonds):
        left = bond_array[k, 0]
        right = bond_array[k, 1]
        # Each end of the bond credits the other end if it has the attribute.
        if binary_attribute_array[right]:
            counts[left] += 1
        if binary_attribute_array[left]:
            counts[right] += 1
    return counts
def fast_vdw_matching(atom_array, bond_array):
    """Evaluate 35 vdW-type SMARTS patterns on every atom without a toolkit.

    Each pattern is decomposed into per-atom boolean masks (element, formal
    charge, connectivity ``X``, hydrogen count ``H``) combined with
    "has a (single-bonded) neighbor with property P" tests computed by the
    numba helpers in this module.

    Parameters
    ----------
    atom_array, bond_array
        Arrays as returned by ``off_mol_to_arrays``.

    Returns
    -------
    (n_atoms, 35) bool array
        ``result[i, k]`` is True when atom ``i`` matches pattern ``k`` as the
        tagged ``:1`` atom; the SMARTS for column ``k`` is written next to the
        corresponding assignment below. Returned as a transposed view.

    Notes
    -----
    ``X``/``H`` are counted from ``bond_array`` alone, so this assumes every
    hydrogen is an explicit atom.

    TODO: speed this up even further if needed, currently only ~200x faster than toolkit,
    but this is such a simple function it should really be on the order of nanoseconds not microseconds...
    """
    # SMARTS "-" bonds: only bonds with order 1 qualify.
    single_bond_array = bond_array[bond_array[:, 2] == 1]
    atomic_number = atom_array[:, 0]
    formal_charge = atom_array[:, 1]

    # charge
    neutral = formal_charge == 0
    pos = np.bitwise_or(formal_charge == 1, formal_charge == 2)  # [*+1,*+2]
    plus_1 = formal_charge == 1
    minus_1 = formal_charge == -1

    # connectivity
    X, H = compute_connectivity(atom_array, bond_array)
    X0 = X == 0
    X2 = X == 2
    X3 = X == 3
    X4 = X == 4
    H0 = H == 0
    H1 = H == 1

    # atomic number primitives
    hydrogen = atomic_number == 1
    lithium = atomic_number == 3
    carbon = atomic_number == 6
    nitrogen = atomic_number == 7
    oxygen = atomic_number == 8
    fluorine = atomic_number == 9
    sodium = atomic_number == 11
    phosphorus = atomic_number == 15
    sulfur = atomic_number == 16
    chlorine = atomic_number == 17
    potassium = atomic_number == 19
    bromine = atomic_number == 35
    rubidium = atomic_number == 37  # Rb+ completes the Li+/Na+/K+/Rb+/Cs+ ion series matched below
    iodine = atomic_number == 53
    cesium = atomic_number == 55

    # derived atomic primitive attributes
    group = [7, 8, 9, 16, 17, 35]  # [#7,#8,#9,#16,#17,#35]
    is_in_group = np.isin(atomic_number, group)
    cx2 = AND(carbon, X2)
    cx3 = AND(carbon, X3)
    cx4 = AND(carbon, X4)

    # bonded (by any bond) to a +1/+2 atom
    b_positive = any_neighbors_w_attribute(pos, bond_array)

    # number of neighbors in group, over any bond ("~") and over single bonds ("-")
    num_neighbors_in_group = n_neighbors_w_attribute(is_in_group, bond_array)
    g_atl_1 = num_neighbors_in_group >= 1
    g_atl_2 = num_neighbors_in_group >= 2
    num_sb_neighbors_in_group = n_neighbors_w_attribute(is_in_group, single_bond_array)
    sb_g_atl_1 = num_sb_neighbors_in_group >= 1
    sb_g_atl_2 = num_sb_neighbors_in_group >= 2
    sb_g_atl_3 = num_sb_neighbors_in_group >= 3

    # TODO: can parallelize over binary attributes also...
    sb_7 = any_neighbors_w_attribute(nitrogen, single_bond_array)
    sb_8 = any_neighbors_w_attribute(oxygen, single_bond_array)
    sb_16 = any_neighbors_w_attribute(sulfur, single_bond_array)
    sb_to_cx4_w_pos_neighbor = any_neighbors_w_attribute(AND(b_positive, cx4), single_bond_array)
    sb_to_cx3_g_atl1 = any_neighbors_w_attribute(AND(g_atl_1, cx3), single_bond_array)
    sb_to_cx3_g_atl2 = any_neighbors_w_attribute(AND(g_atl_2, cx3), single_bond_array)
    sb_to_cx4_sb_g_atl1 = any_neighbors_w_attribute(AND(sb_g_atl_1, cx4), single_bond_array)
    sb_to_cx4_sb_g_atl2 = any_neighbors_w_attribute(AND(sb_g_atl_2, cx4), single_bond_array)
    sb_to_cx4_sb_g_atl3 = any_neighbors_w_attribute(AND(sb_g_atl_3, cx4), single_bond_array)

    # single-bonded to cx2, cx3, cx4
    sb_cx2 = any_neighbors_w_attribute(cx2, single_bond_array)
    sb_cx3 = any_neighbors_w_attribute(cx3, single_bond_array)
    sb_cx4 = any_neighbors_w_attribute(cx4, single_bond_array)

    # form match matrix (one row per pattern; transposed on return).
    # Builtin bool instead of np.bool, which NumPy 1.24-1.26 no longer provides.
    matches = np.zeros((35, len(atom_array)), dtype=bool)
    matches[0] = hydrogen  # 0 "[#1:1]"
    matches[1] = AND(hydrogen, sb_cx4)  # 1 "[#1:1]-[#6X4]"
    matches[2] = AND(hydrogen, sb_to_cx4_sb_g_atl1)  # 2 "[#1:1]-[#6X4]-[#7,#8,#9,#16,#17,#35]"
    matches[3] = AND(hydrogen, sb_to_cx4_sb_g_atl2)  # 3 "[#1:1]-[#6X4](-[#7,#8,#9,#16,#17,#35])-[#7,#8,#9,#16,#17,#35]"
    matches[4] = AND(hydrogen, sb_to_cx4_sb_g_atl3)  # 4 "[#1:1]-[#6X4](-[#7,#8,#9,#16,#17,#35])(-[#7,#8,#9,#16,#17,#35])-[#7,#8,#9,#16,#17,#35]"
    matches[5] = AND(hydrogen, sb_to_cx4_w_pos_neighbor)  # 5 "[#1:1]-[#6X4]~[*+1,*+2]"
    matches[6] = AND(hydrogen, sb_cx3)  # 6 "[#1:1]-[#6X3]"
    matches[7] = AND(hydrogen, sb_to_cx3_g_atl1)  # 7 "[#1:1]-[#6X3]~[#7,#8,#9,#16,#17,#35]"
    matches[8] = AND(hydrogen, sb_to_cx3_g_atl2)  # 8 "[#1:1]-[#6X3](~[#7,#8,#9,#16,#17,#35])~[#7,#8,#9,#16,#17,#35]"
    matches[9] = AND(hydrogen, sb_cx2)  # 9 "[#1:1]-[#6X2]"
    matches[10] = AND(hydrogen, sb_7)  # 10 "[#1:1]-[#7]"
    matches[11] = AND(hydrogen, sb_8)  # 11 "[#1:1]-[#8]"
    matches[12] = AND(hydrogen, sb_16)  # 12 "[#1:1]-[#16]"
    matches[13] = carbon  # 13 "[#6:1]"
    matches[14] = cx2  # 14 "[#6X2:1]"
    matches[15] = cx4  # 15 "[#6X4:1]"
    matches[16] = oxygen  # 16 "[#8:1]"
    matches[17] = AND(AND(oxygen, X2), AND(H0, neutral))  # 17 "[#8X2H0+0:1]"
    matches[18] = AND(AND(oxygen, X2), AND(H1, neutral))  # 18 "[#8X2H1+0:1]"
    matches[19] = nitrogen  # 19 "[#7:1]"
    matches[20] = sulfur  # 20 "[#16:1]"
    matches[21] = phosphorus  # 21 "[#15:1]"
    matches[22] = fluorine  # 22 "[#9:1]"
    matches[23] = chlorine  # 23 "[#17:1]"
    matches[24] = bromine  # 24 "[#35:1]"
    matches[25] = iodine  # 25 "[#53:1]"
    matches[26] = AND(lithium, plus_1)  # 26 "[#3+1:1]"
    matches[27] = AND(sodium, plus_1)  # 27 "[#11+1:1]"
    matches[28] = AND(potassium, plus_1)  # 28 "[#19+1:1]"
    matches[29] = AND(rubidium, plus_1)  # 29 "[#37+1:1]"
    matches[30] = AND(cesium, plus_1)  # 30 "[#55+1:1]"
    matches[31] = AND(AND(fluorine, X0), minus_1)  # 31 "[#9X0-1:1]"
    matches[32] = AND(AND(chlorine, X0), minus_1)  # 32 "[#17X0-1:1]"
    matches[33] = AND(AND(bromine, X0), minus_1)  # 33 "[#35X0-1:1]"
    matches[34] = AND(AND(iodine, X0), minus_1)  # 34 "[#53X0-1:1]"
    return matches.T
if __name__ == '__main__':
    # Benchmark / regression check: compare fast_vdw_matching against
    # precomputed toolkit matches stored in a local pickle.
    from pickle import load

    from tqdm import tqdm

    print('loading mols and matches...')
    # NOTE: unpickling can execute arbitrary code -- only load a trusted local file.
    with open('mols_and_matches.pkl', 'rb') as f:
        mols, correct_matches = load(f)

    print('calling matcher for first time...')
    mol_arrays = [off_mol_to_arrays(mol) for mol in mols]
    # The first call compiles (or loads from cache) the numba kernels, so it is
    # done once here rather than inside the timed loop; skip it if there is no input.
    if mol_arrays:
        first_atom_array, first_bond_array = mol_arrays[0]
        fast_vdw_matching(first_atom_array, first_bond_array)

    print('startin main loop...')
    fast_matches = [
        fast_vdw_matching(atom_array, bond_array)
        for (atom_array, bond_array) in tqdm(mol_arrays)
    ]

    print('checking correctness...')
    if len(correct_matches) != len(fast_matches):
        raise ValueError(
            f'expected {len(fast_matches)} reference match arrays, got {len(correct_matches)}'
        )
    for i, (correct, fast) in enumerate(zip(correct_matches, fast_matches)):
        # Mismatched shapes would make `==` collapse to a scalar (or raise),
        # so report them explicitly instead of comparing elementwise.
        if np.shape(correct) != np.shape(fast):
            print('problem!')
            print('molecule', i, 'shape mismatch:', np.shape(correct), np.shape(fast))
            continue
        agreement = correct == fast
        # np.all replaces np.alltrue, which was removed in NumPy 2.0.
        if not np.all(agreement):
            disagreement = np.bitwise_not(agreement)
            print('problem!')
            print(np.sum(disagreement, 0))  # mismatches per pattern
            print(np.where(disagreement))  # (atom, pattern) indices of mismatches
            print(np.sum(disagreement, 1))  # mismatches per atom
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment