Last active
July 30, 2021 04:04
-
-
Save yashbonde/ab25e34834e71519178b93d3b5a26c19 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""simple script to train a simple graph network to train a simple drug classifier | |
18.04.2020 - @yashbonde""" | |
from tqdm import trange | |
import json | |
import numpy as np | |
from pysmiles import read_smiles | |
import networkx as nx | |
import pandas as pd # to make the final DF | |
from types import SimpleNamespace | |
from collections import namedtuple | |
import matplotlib.pyplot as plt | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
import torch_geometric as tgx | |
from torch_scatter import scatter_mean | |
from torch_geometric.data import Data | |
# network classes | |
class GNNBlock(nn.Module):
    """One message-passing step: update edge features, then node features.

    The edge update sees both endpoint features and the current edge feature.
    The node update builds one message per edge from the source node and the
    updated edge, averages the messages onto each edge's target node (``col``),
    and combines that average with the node's own feature. Output feature sizes
    equal the input sizes, so blocks can be stacked.

    Args:
        node_dim: node feature size in and out (F_x).
        node_dim_hid: hidden width of both node MLPs.
        edge_dim: edge feature size in and out (F_e).
        edge_dim_hid: hidden width of the edge MLP.
    """

    def __init__(self, node_dim, node_dim_hid, edge_dim, edge_dim_hid):
        super(GNNBlock, self).__init__()
        # edge mlp: [2*F_x + F_e] -> [edge_dim_hid] -> [F_e]
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * node_dim + edge_dim, edge_dim_hid),
            nn.ReLU(),
            nn.Linear(edge_dim_hid, edge_dim)
        )
        # node mlp1 (per-edge message): [F_x + F_e] -> [node_dim_hid] -> [F_x]
        self.node_mlp1 = nn.Sequential(
            nn.Linear(node_dim + edge_dim, node_dim_hid),
            nn.ReLU(),
            nn.Linear(node_dim_hid, node_dim)
        )
        # node mlp2 (node update): [2*F_x] -> [node_dim_hid] -> [F_x]
        self.node_mlp2 = nn.Sequential(
            nn.Linear(2 * node_dim, node_dim_hid),
            nn.ReLU(),
            nn.Linear(node_dim_hid, node_dim)
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every layer of the three MLPs.

        ``nn.Sequential`` has no ``reset_parameters`` of its own, so descend
        into each container and reset every child layer that supports it
        (the Linear layers; ReLU has no parameters).
        """
        for mlp in (self.node_mlp1, self.node_mlp2, self.edge_mlp):
            for layer in mlp.children():
                if hasattr(layer, 'reset_parameters'):
                    layer.reset_parameters()

    def __repr__(self):
        return ('{}(\n'
                ' edge_mlp={},\n'
                ' node_mlp1={},\n'
                ' node_mlp2={},\n'
                ')').format(self.__class__.__name__, self.edge_mlp,
                            self.node_mlp1, self.node_mlp2)

    def forward(self, x, edge_index, edge_attr):
        """Run one round of edge and node updates.

        Args:
            x: node features, [N, F_x].
            edge_index: [2, E] long tensor of (row, col) node indices, with
                entries in [0, N-1].
            edge_attr: edge features, [E, F_e].

        Returns:
            Tuple of updated node features [N, F_x], the unchanged
            ``edge_index``, and updated edge features [E, F_e].
        """
        row, col = edge_index
        # 1) edge update from both endpoint features and the current edge feature
        edge_attr = self.edge_mlp(
            torch.cat([x[row], x[col], edge_attr], dim=-1))  # [E, F_e]
        # 2) one message per edge from its row node and the updated edge
        messages = self.node_mlp1(
            torch.cat([x[row], edge_attr], dim=-1))  # [E, F_x]
        # 3) mean of the messages arriving at each col node; dim_size keeps
        #    the output at N rows even if some nodes receive no message
        aggregated = scatter_mean(messages, col, dim=0,
                                  dim_size=x.size(0))  # [N, F_x]
        # 4) node update from the node's own feature and its aggregated messages
        node_out = self.node_mlp2(
            torch.cat([x, aggregated], dim=-1))  # [N, F_x]
        return node_out, edge_index, edge_attr

    def test_sizes(self):
        """Smoke-test forward() on a 4-node, 6-edge toy graph and check shapes.

        The feature sizes are read from this block's own layers, so the check
        works for any node/edge dimensions. Parameters are re-initialised
        afterwards.
        """
        node_dim = self.node_mlp2[-1].out_features
        edge_dim = self.edge_mlp[-1].out_features
        x = torch.randn(4, node_dim)
        edge_index = torch.tensor([[0, 1, 1, 2, 1, 3],
                                   [1, 0, 2, 1, 3, 1]], dtype=torch.long)
        edge_attr = torch.randn(6, edge_dim)
        x, edge_index, edge_attr = self.forward(x, edge_index, edge_attr)
        print(x.shape, edge_index.shape, edge_attr.shape)
        assert x.shape == torch.Size([4, node_dim])
        assert edge_index.shape == torch.Size([2, 6])
        assert edge_attr.shape == torch.Size([6, edge_dim])
        print("Test Passed")
        self.reset_parameters()
class ClassisiferNetwork(nn.Module):
    """Molecule graph classifier.

    Pipeline: dense projections of the one-hot node/edge features, two
    GNNBlocks (each followed by node and edge batch norm), a GCN layer with
    batch norm, and per-graph max pooling. The pooled graph embedding is
    concatenated with the one-hot ``an_emb`` stored on each sample by
    ``make_data`` and fed to a linear classifier.
    """

    def __init__(self,
                 num_elements,
                 num_edges,
                 node_dim,
                 node_extra_feat,
                 edge_dim,
                 num_classes,
                 batch_size,
                 ):
        """
        num_elements: int, width of the element one-hot; every element index
            given to make_data must be < num_elements
        num_edges: int, width of the bond one-hot; every int(bond order)
            given to make_data must be < num_edges
        node_dim: int, hidden node feature size
        node_extra_feat: int, number of non-element node features built by
            make_data (charge + aromatic + 4-way hcount one-hot = 6)
        edge_dim: int, hidden edge feature size
        num_classes: int, number of output classes
        batch_size: int, not used by the network (kept for interface
            compatibility)
        """
        super(ClassisiferNetwork, self).__init__()
        self.node_onehot = np.eye(num_elements).astype(np.float32)
        self.edge_onehot = np.eye(num_edges).astype(np.float32)
        # supports hcount values 0..3
        self.hcount_onehot = np.eye(4).astype(np.float32)
        # supports `an` values 0..3
        self.animal_onehot = np.eye(4).astype(np.float32)
        num_animals = self.animal_onehot.shape[1]

        # linear projections of the raw one-hot features
        self.node_dense = nn.Sequential(
            nn.Linear(num_elements + node_extra_feat, node_dim),
            nn.ReLU()
        )
        self.edge_dense = nn.Sequential(
            nn.Linear(num_edges, edge_dim),
            nn.ReLU()
        )

        # gnn blocks + batch norms; the edge norms act on edge features and
        # are therefore sized by edge_dim
        self.gnn1 = GNNBlock(node_dim, node_dim * 2, edge_dim, 2 * edge_dim)
        self.bn_n1 = tgx.nn.BatchNorm(node_dim)
        self.bn_e1 = tgx.nn.BatchNorm(edge_dim)
        self.gnn2 = GNNBlock(node_dim, node_dim * 2, edge_dim, 2 * edge_dim)
        self.bn_n2 = tgx.nn.BatchNorm(node_dim)
        self.bn_e2 = tgx.nn.BatchNorm(edge_dim)
        self.gcn = tgx.nn.GCNConv(node_dim, node_dim)
        self.bn_n3 = tgx.nn.BatchNorm(node_dim)

        # classifier input = pooled graph embedding + animal one-hot
        self.pred_mlp = nn.Linear(node_dim + num_animals, num_classes)

    def make_data(self, gx, y, an):
        """Convert a preprocessed molecule graph into a torch_geometric Data.

        Args:
            gx: graph whose nodes carry integer "element", "aromatic" and
                "hcount" attributes (as written by pre_process_molecule) plus
                a numeric "charge", and whose edges carry an "order".
            y: label(s) for the sample, stored as a long tensor.
            an: index into the 4-way animal one-hot, stored as ``an_emb``
                with shape [1, 4].

        Returns:
            Data with x [N, num_elements + 6], edge_index [2, 2E],
            edge_attr [2E, num_edges], an_emb [1, 4] and y.
        """
        nodes = list(gx.nodes.data())

        # node features: element one-hot + (charge, aromatic, hcount one-hot)
        node_idx = np.array([attrs["element"] for _, attrs in nodes]
                            ).astype(np.int64)
        node_feats = np.array([
            (attrs["charge"], attrs["aromatic"],
             *self.hcount_onehot[attrs["hcount"]])
            for _, attrs in nodes
        ]).astype(np.float32)
        x = np.hstack([self.node_onehot[node_idx], node_feats])

        # each undirected bond becomes two directed edges with the same feature
        edge_pairs = []
        bond_orders = []
        for u, v, attrs in gx.edges.data():
            edge_pairs.append((u, v))
            edge_pairs.append((v, u))
            bond_orders.extend([attrs["order"]] * 2)
        # reshape(-1, 2) keeps edge_index at shape [2, 0] for bond-less graphs
        edge_index = np.array(edge_pairs).astype(np.int64).reshape(-1, 2).T
        # NOTE(review): the int64 cast truncates fractional bond orders
        # (e.g. 1.5 -> 1), merging them with the next lower integer order
        bond_idx = np.array(bond_orders).astype(np.int64)
        edge_attr = self.edge_onehot[bond_idx]

        return Data(
            x=torch.from_numpy(x),
            edge_index=torch.from_numpy(edge_index),
            edge_attr=torch.from_numpy(edge_attr),
            an_emb=torch.from_numpy(np.expand_dims(
                self.animal_onehot[an], axis=0)),
            y=torch.tensor(y, dtype=torch.long)
        )

    def forward(self, gx):
        """Return class logits, shape [num_graphs, num_classes], for a batch."""
        x, edge_index, edge_attr = gx.x, gx.edge_index, gx.edge_attr

        # dense projections
        node_out = self.node_dense(x)
        edge_out = self.edge_dense(edge_attr)

        # gnn blocks + normalisation
        x, edge_index, edge_attr = self.gnn1(node_out, edge_index, edge_out)
        x = self.bn_n1(x)
        edge_attr = self.bn_e1(edge_attr)
        x, edge_index, edge_attr = self.gnn2(x, edge_index, edge_attr)
        x = self.bn_n2(x)
        # edge features are not consumed after this point (GCNConv is given
        # only node features and edge_index)
        edge_attr = self.bn_e2(edge_attr)
        node_out = self.bn_n3(self.gcn(x, edge_index))

        # per-graph readout, conditioned on the animal one-hot
        max_pooled_graph = tgx.nn.global_max_pool(node_out, gx.batch)
        graph_emb = torch.cat([max_pooled_graph, gx.an_emb], dim=-1)

        # final classification layer
        carcinogenic_pred = self.pred_mlp(graph_emb)
        return carcinogenic_pred
# helper functions | |
def calculate_accuracy(target, pred):
    """Return the fraction of samples whose arg-max score matches the target.

    Args:
        target: integer class labels with B elements in total (e.g. [B] or
            [B, 1]).
        pred: class scores/logits, [B, C]; extra singleton dimensions are
            flattened away.

    Returns:
        0-dim float tensor in [0, 1].
    """
    # argmax is unchanged by softmax, so take it on the raw scores. Reshape
    # instead of squeeze so a batch of size 1 keeps its batch axis.
    pred_labels = pred.reshape(-1, pred.shape[-1]).argmax(dim=-1)
    correct = pred_labels == target.reshape(-1)
    return correct.float().mean()
def pre_process_molecule(gx, element_to_idx=None):
    """Convert a molecule graph's node attributes to integers, in place.

    For every node, "element" (a symbol) is replaced by its index from
    ``element_to_idx``, "aromatic" becomes 0/1, and "hcount" is cast to int.
    All other node attributes are left untouched. Apply at most once per
    graph: a second pass would look up the already-converted integer element
    in the symbol table.

    Args:
        gx: graph exposing ``gx.nodes.data()`` as (node, attr_dict) pairs.
        element_to_idx: mapping from element symbol to integer index. Defaults
            to the module-level ``elements2idx`` set up by the ``__main__``
            script.

    Raises:
        KeyError: if a node's element is missing from the mapping.
    """
    if element_to_idx is None:
        element_to_idx = elements2idx
    for _, attrs in gx.nodes.data():
        attrs.update({
            "element": element_to_idx[attrs["element"]],
            "aromatic": int(attrs["aromatic"]),
            "hcount": int(attrs["hcount"])
        })
if __name__ == "__main__":
    BATCH_SIZE = 128
    LEARNING_RATE = 0.001
    NUM_EPOCHS = 2

    # load the CSV (columns used below: "smiles" and carcinogenic_{FM,MM,FR,MR})
    fdata = pd.read_csv("PTC_pn/merged.csv")

    # load graphs
    print("Making all graphs ...")
    graphs_all = [read_smiles(smiles) for smiles in fdata["smiles"]]
    print("... Done!")

    print("Loading Node and Edge Attribute data ...")
    all_node_attrs = [attrs for gx in graphs_all for _, attrs in gx.nodes.data()]

    # [NODE] element symbol -> index table. pre_process_molecule reads this
    # module-level name, so it must stay a global of the script.
    elements2idx = {"Zn": 0, "Ba": 1, "Sn": 2, "C": 3, "Cl": 4, "K": 5, "S": 6, "N": 7, "Cu": 8, "Na": 9, "F": 10, "Ca": 11, "O": 12, "In": 13, "P": 14, "Pb": 15, "B": 16, "As": 17, "Te": 18, "I": 19, "Br": 20}
    elements = list({attrs["element"] for attrs in all_node_attrs})
    # elements present in the data but missing from the fixed table get fresh
    # indices, so preprocessing cannot raise KeyError
    next_idx = max(elements2idx.values()) + 1
    for symbol in sorted(set(elements) - set(elements2idx)):
        elements2idx[symbol] = next_idx
        next_idx += 1
    # the element one-hot must cover every index in the table, not only the
    # elements that happen to occur in this CSV
    NUM_ELEMENTS = next_idx
    print("Elements: ", elements, NUM_ELEMENTS)

    # [NODE] number of attached hydrogen atoms
    num_hcounts = list({attrs["hcount"] for attrs in all_node_attrs})
    NUM_HCOUNTS = len(num_hcounts)
    print("Number of attached Hydrogen atoms:", num_hcounts, NUM_HCOUNTS)
    # NOTE(review): make_data one-hot encodes hcount with np.eye(4), so an
    # hcount >= 4 anywhere in the data will raise IndexError there.

    # [NODE] whether aromatic or not
    aromatic = list({attrs["aromatic"] for attrs in all_node_attrs})
    print("aromatic:", aromatic)

    # [NODE] charge on the atom
    charge = list({attrs["charge"] for attrs in all_node_attrs})
    print("charge:", charge)

    # [EDGE] bond orders
    bonds = list({attrs["order"] for gx in graphs_all
                  for _, _, attrs in gx.edges.data()})
    # make_data one-hot encodes int(order), so the width must exceed the
    # largest truncated order, not merely equal the number of distinct orders
    NUM_BONDS = int(max(bonds, default=0)) + 1
    print("information about the bond:", bonds, NUM_BONDS)
    print("... Done!")

    # preprocess the graphs (in place)
    print("Preprocessing graphs ...")
    for gx in graphs_all:
        pre_process_molecule(gx)
    print("... Done!")

    cl = ClassisiferNetwork(
        num_elements=NUM_ELEMENTS,
        num_edges=NUM_BONDS,
        node_dim=32,
        node_extra_feat=2 + 4,
        edge_dim=32,
        num_classes=2,
        batch_size=20
    )

    # label column -> `an` index passed to make_data (one index per column)
    target_columns = (
        ("carcinogenic_FM", 0),
        ("carcinogenic_MM", 1),
        ("carcinogenic_FR", 2),
        ("carcinogenic_MR", 3),
    )
    dataset = []
    for idx, gx in enumerate(graphs_all):
        if idx == 214:
            # issue with this graph (reason not recorded); skipped
            continue
        for column, animal in target_columns:
            label = int(fdata[column][idx])
            # a label of 0 produces no sample; negative labels become class 0
            if label != 0:
                dataset.append(cl.make_data(gx, [max(label, 0)], an=animal))
    print("Number of learning samples:", len(dataset))

    pytorch_total_params = sum(p.numel() for p in cl.parameters())
    print("Number of network parameters", pytorch_total_params)

    loss_fn = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(cl.parameters(), lr=LEARNING_RATE)
    # samples are appended molecule by molecule, so shuffle each epoch
    loader = tgx.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    losses = []
    acc = []
    cl.train()
    for _epoch in trange(NUM_EPOCHS):
        for batch in loader:
            optim.zero_grad()
            net_pred = cl(batch)
            loss_total = loss_fn(net_pred, batch.y)
            loss_total.backward()
            optim.step()
            # store plain floats so each step's autograd graph can be freed
            losses.append(loss_total.item())
            acc.append(calculate_accuracy(batch.y, net_pred.detach()).item())

    plt.figure(figsize=(10, 6))
    plt.title(f"Losses (batch_size = {BATCH_SIZE}; lr = {LEARNING_RATE})")
    plt.plot(losses, label="Loss")
    plt.plot(acc, label="Accuracy")
    plt.xlabel("Number of steps")
    plt.legend()
    plt.savefig("status.png")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
For more information about this gist, visit my blog, where I explain how to build a PTC drug classifier.