Skip to content

Instantly share code, notes, and snippets.

@maxentile
Created February 21, 2020 19:36
Show Gist options
  • Save maxentile/dcf438b0300b48534341a0cbf80d7711 to your computer and use it in GitHub Desktop.
Save maxentile/dcf438b0300b48534341a0cbf80d7711 to your computer and use it in GitHub Desktop.
import numpy as np
from numba import jit
# Short alias for the element-wise AND used below to combine per-atom boolean masks.
AND = np.bitwise_and
def off_mol_to_arrays(mol):
    """flatten an openforcefield.topology.Molecule into two plain integer arrays
    returns a tuple (atom_array, bond_array) where
    * atom_array is (n_atoms, 2) int32, one row per atom: (atomic_number, formal_charge)
    * bond_array is (n_bonds, 3) int32, one row per bond: (index1, index2, bond_order)
    """
    atom_list, bond_list = mol.atoms, mol.bonds

    # pre-allocate using the counts reported by the molecule, then fill column by column
    atom_table = np.zeros((mol.n_atoms, 2), dtype=np.int32)
    bond_table = np.zeros((mol.n_bonds, 3), dtype=np.int32)

    for row, current_atom in enumerate(atom_list):
        atom_table[row] = (current_atom.atomic_number, current_atom.formal_charge)

    for row, current_bond in enumerate(bond_list):
        bond_table[row] = (
            current_bond.atom1_index,
            current_bond.atom2_index,
            current_bond.bond_order,
        )

    return atom_table, bond_table
@jit
def compute_connectivity(atom_array, bond_array):
    """tally bonded neighbors per atom from (atom_array, bond_array) as output by off_mol_to_arrays
    returns
    * X: (n_atoms,) int8 array, X[i] = number of bond rows atom i appears in
    * H: (n_atoms,) int8 array, H[i] = number of those bonds whose other atom has atomic number 1
    """
    elements = atom_array[:, 0]
    X = np.zeros_like(elements, dtype=np.int8)
    H = np.zeros_like(elements, dtype=np.int8)

    n_bonds = len(bond_array)
    for k in range(n_bonds):
        first = bond_array[k, 0]
        second = bond_array[k, 1]

        # every bond adds one neighbor to each of its two endpoints
        X[first] += 1
        X[second] += 1

        # a hydrogen endpoint counts toward the H tally of the opposite endpoint
        if elements[first] == 1:
            H[second] += 1
        if elements[second] == 1:
            H[first] += 1

    return X, H
@jit(cache=True)
def any_neighbors_w_attribute(binary_attribute_array, bond_array):
    """loop over bonds and flag atoms that have at least one bonded neighbor with the attribute

    parameters
    * binary_attribute_array: (n_atoms,) boolean array, one flag per atom
    * bond_array: integer array with one row per bond; columns 0 and 1 hold the two atom indices
      (any further columns are ignored here)

    returns
    * (n_atoms,) boolean array, where presence[i] means atom[i] has at least
      one neighbor j such that binary_attribute_array[j]==True
    """
    # np.bool_ rather than the np.bool alias: the alias was removed in NumPy 1.24
    presence = np.zeros_like(binary_attribute_array, dtype=np.bool_)
    for i in range(len(bond_array)):
        a, b = bond_array[i, 0], bond_array[i, 1]
        # a bond is symmetric: each endpoint's attribute marks the other endpoint
        if binary_attribute_array[a]:
            presence[b] = True
        if binary_attribute_array[b]:
            presence[a] = True
    return presence
@jit(cache=True)
def n_neighbors_w_attribute(binary_attribute_array, bond_array):
    """walk the bond list and count, per atom, the bonded neighbors that carry the attribute
    returns
    * an (n_atoms,) int8 array, where counts[i] == n means atom[i] has exactly
      n neighbors j such that binary_attribute_array[j]==True
    """
    counts = np.zeros_like(binary_attribute_array, dtype=np.int8)

    n_bonds = len(bond_array)
    for k in range(n_bonds):
        first = bond_array[k, 0]
        second = bond_array[k, 1]

        # each flagged endpoint contributes one to the opposite endpoint's count
        if binary_attribute_array[first]:
            counts[second] += 1
        if binary_attribute_array[second]:
            counts[first] += 1

    return counts
def fast_vdw_matching(atom_array, bond_array):
    """return an (n_atoms, 35) binary array of smarts matches

    parameters
    * atom_array: (n_atoms, 2) integer array of (atomic_number, formal_charge) per atom
    * bond_array: (n_bonds, 3) integer array of (index1, index2, bond_order) per bond

    returns
    * (n_atoms, 35) boolean array; entry [i, k] is True when atom i matches pattern k,
      where the 35 patterns are the SMARTS strings listed next to each row assignment below

    TODO: speed this up even further if needed, currently only ~200x faster than toolkit,
    but this is such a simple function it should really be on the order of nanoseconds not microseconds...
    """
    # bonds with bond_order == 1, used for patterns that require an explicit "-" bond
    single_bond_array = bond_array[bond_array[:, 2] == 1]

    atomic_number = atom_array[:, 0]
    formal_charge = atom_array[:, 1]

    # charge
    neutral = formal_charge == 0
    pos = np.bitwise_or(formal_charge == 1, formal_charge == 2)
    plus_1 = formal_charge == 1
    minus_1 = formal_charge == -1

    # connectivity (X = total neighbors, H = hydrogen neighbors)
    X, H = compute_connectivity(atom_array, bond_array)
    X0 = X == 0
    X2 = X == 2
    X3 = X == 3
    X4 = X == 4
    H0 = H == 0
    H1 = H == 1

    # atomic number primitives
    hydrogen = atomic_number == 1
    lithium = atomic_number == 3
    carbon = atomic_number == 6
    nitrogen = atomic_number == 7
    oxygen = atomic_number == 8
    fluorine = atomic_number == 9
    sodium = atomic_number == 11
    phosphorus = atomic_number == 15
    sulfur = atomic_number == 16
    chlorine = atomic_number == 17
    potassium = atomic_number == 19
    bromine = atomic_number == 35
    rubidium = atomic_number == 37  # TODO: huh, exotic enough to probably be a typo? (seems to have been introduced in smirnoff99Frosst 1.0.6...)
    iodine = atomic_number == 53
    cesium = atomic_number == 55

    # derived atomic primitive attributes
    group = [7, 8, 9, 16, 17, 35]
    is_in_group = np.isin(atomic_number, group)
    cx2 = AND(carbon, X2)
    cx3 = AND(carbon, X3)
    cx4 = AND(carbon, X4)

    # bonded to positive
    b_positive = any_neighbors_w_attribute(pos, bond_array)

    # num neighbors (or single-bond-neighbors) in group
    num_neighbors_in_group = n_neighbors_w_attribute(is_in_group, bond_array)
    g_atl_1 = num_neighbors_in_group >= 1
    g_atl_2 = num_neighbors_in_group >= 2
    num_sb_neighbors_in_group = n_neighbors_w_attribute(is_in_group, single_bond_array)
    sb_g_atl_1 = num_sb_neighbors_in_group >= 1
    sb_g_atl_2 = num_sb_neighbors_in_group >= 2
    sb_g_atl_3 = num_sb_neighbors_in_group >= 3

    # TODO: can parallelize over binary attributes also...
    sb_7 = any_neighbors_w_attribute(nitrogen, single_bond_array)
    sb_8 = any_neighbors_w_attribute(oxygen, single_bond_array)
    sb_16 = any_neighbors_w_attribute(sulfur, single_bond_array)
    sb_to_cx4_w_pos_neighbor = any_neighbors_w_attribute(AND(b_positive, cx4), single_bond_array)
    sb_to_cx3_g_atl1 = any_neighbors_w_attribute(AND(g_atl_1, cx3), single_bond_array)
    sb_to_cx3_g_atl2 = any_neighbors_w_attribute(AND(g_atl_2, cx3), single_bond_array)
    sb_to_cx4_sb_g_atl1 = any_neighbors_w_attribute(AND(sb_g_atl_1, cx4), single_bond_array)
    sb_to_cx4_sb_g_atl2 = any_neighbors_w_attribute(AND(sb_g_atl_2, cx4), single_bond_array)
    sb_to_cx4_sb_g_atl3 = any_neighbors_w_attribute(AND(sb_g_atl_3, cx4), single_bond_array)

    # single-bonded to cx2, cx3, cx4
    sb_cx2 = any_neighbors_w_attribute(cx2, single_bond_array)
    sb_cx3 = any_neighbors_w_attribute(cx3, single_bond_array)
    sb_cx4 = any_neighbors_w_attribute(cx4, single_bond_array)

    # form match matrix (one row per pattern; transposed on return)
    # builtin bool rather than the np.bool alias: the alias was removed in NumPy 1.24
    matches = np.zeros((35, len(atom_array)), dtype=bool)
    matches[0] = hydrogen  # 0 "[#1:1]"
    matches[1] = AND(hydrogen, sb_cx4)  # 1 "[#1:1]-[#6X4]"
    matches[2] = AND(hydrogen, sb_to_cx4_sb_g_atl1)  # 2 "[#1:1]-[#6X4]-[#7,#8,#9,#16,#17,#35]"
    matches[3] = AND(hydrogen, sb_to_cx4_sb_g_atl2)  # 3 "[#1:1]-[#6X4](-[#7,#8,#9,#16,#17,#35])-[#7,#8,#9,#16,#17,#35]"
    matches[4] = AND(hydrogen, sb_to_cx4_sb_g_atl3)  # 4 "[#1:1]-[#6X4](-[#7,#8,#9,#16,#17,#35])(-[#7,#8,#9,#16,#17,#35])-[#7,#8,#9,#16,#17,#35]"
    matches[5] = AND(hydrogen, sb_to_cx4_w_pos_neighbor)  # 5 "[#1:1]-[#6X4]~[*+1,*+2]"
    matches[6] = AND(hydrogen, sb_cx3)  # 6 "[#1:1]-[#6X3]"
    matches[7] = AND(hydrogen, sb_to_cx3_g_atl1)  # 7 "[#1:1]-[#6X3]~[#7,#8,#9,#16,#17,#35]"
    matches[8] = AND(hydrogen, sb_to_cx3_g_atl2)  # 8 "[#1:1]-[#6X3](~[#7,#8,#9,#16,#17,#35])~[#7,#8,#9,#16,#17,#35]"
    matches[9] = AND(hydrogen, sb_cx2)  # 9 "[#1:1]-[#6X2]"
    matches[10] = AND(hydrogen, sb_7)  # 10 "[#1:1]-[#7]"
    matches[11] = AND(hydrogen, sb_8)  # 11 "[#1:1]-[#8]"
    matches[12] = AND(hydrogen, sb_16)  # 12 "[#1:1]-[#16]"
    matches[13] = carbon  # 13 "[#6:1]"
    matches[14] = AND(carbon, X2)  # 14 "[#6X2:1]"
    matches[15] = AND(carbon, X4)  # 15 "[#6X4:1]"
    matches[16] = oxygen  # 16 "[#8:1]"
    matches[17] = AND(AND(oxygen, X2), AND(H0, neutral))  # 17 "[#8X2H0+0:1]"
    matches[18] = AND(AND(oxygen, X2), AND(H1, neutral))  # 18 "[#8X2H1+0:1]"
    matches[19] = nitrogen  # 19 "[#7:1]"
    matches[20] = sulfur  # 20 "[#16:1]"
    matches[21] = phosphorus  # 21 "[#15:1]"
    matches[22] = fluorine  # 22 "[#9:1]"
    matches[23] = chlorine  # 23 "[#17:1]"
    matches[24] = bromine  # 24 "[#35:1]"
    matches[25] = iodine  # 25 "[#53:1]"
    matches[26] = AND(lithium, plus_1)  # 26 "[#3+1:1]"
    matches[27] = AND(sodium, plus_1)  # 27 "[#11+1:1]"
    matches[28] = AND(potassium, plus_1)  # 28 "[#19+1:1]"
    matches[29] = AND(rubidium, plus_1)  # 29 "[#37+1:1]"
    matches[30] = AND(cesium, plus_1)  # 30 "[#55+1:1]"
    matches[31] = AND(AND(fluorine, X0), minus_1)  # 31 "[#9X0-1:1]"
    matches[32] = AND(AND(chlorine, X0), minus_1)  # 32 "[#17X0-1:1]"
    matches[33] = AND(AND(bromine, X0), minus_1)  # 33 "[#35X0-1:1]"
    matches[34] = AND(AND(iodine, X0), minus_1)  # 34 "[#53X0-1:1]"
    return matches.T
if __name__ == '__main__':
    # Self-check script: run the fast matcher on every stored molecule and
    # report any entries that differ from the stored reference matches.
    from tqdm import tqdm
    from pickle import load

    print('loading mols and matches...')
    # NOTE(review): unpickling can execute arbitrary code; only run this against a trusted file.
    with open('mols_and_matches.pkl', 'rb') as f:
        mols, correct_matches = load(f)

    print('calling matcher for first time...')
    fast_matches = []
    mol_arrays = list(map(off_mol_to_arrays, mols))
    # presumably a warm-up call so JIT compilation happens before the progress-bar loop
    atom_array, bond_array = mol_arrays[0]
    fast_vdw_matching(atom_array, bond_array)

    print('startin main loop...')
    for (atom_array, bond_array) in tqdm(mol_arrays):
        fast_matches.append(fast_vdw_matching(atom_array, bond_array))

    print('checking correctness...')
    for i, _mol in enumerate(mols):
        agreement = correct_matches[i] == fast_matches[i]
        disagreement = np.bitwise_not(agreement)
        # np.all rather than np.alltrue: the latter was removed in NumPy 2.0
        if not np.all(agreement):
            print('problem!')
            print(np.sum(disagreement, 0))
            print(np.where(disagreement > 0))
            print(np.sum(disagreement, 1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment