Skip to content

Instantly share code, notes, and snippets.

@yashbonde
Last active July 30, 2021 04:04
Show Gist options
  • Save yashbonde/ab25e34834e71519178b93d3b5a26c19 to your computer and use it in GitHub Desktop.
Save yashbonde/ab25e34834e71519178b93d3b5a26c19 to your computer and use it in GitHub Desktop.
"""simple script to train a simple graph network to train a simple drug classifier
18.04.2020 - @yashbonde"""
from tqdm import trange
import json
import numpy as np
from pysmiles import read_smiles
import networkx as nx
import pandas as pd # to make the final DF
from types import SimpleNamespace
from collections import namedtuple
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch import nn
import torch_geometric as tgx
from torch_scatter import scatter_mean
from torch_geometric.data import Data
# network classes
class GNNBlock(nn.Module):
    """One round of edge-then-node message passing on a graph.

    1. Every edge feature is updated from the features of its two endpoint
       nodes and its own features (``edge_mlp``).
    2. A message is built per edge from the source node and the updated edge
       (``node_mlp1``) and mean-aggregated at the target node.
    3. Each node combines its own features with that aggregate (``node_mlp2``).

    Node and edge feature sizes are preserved, so blocks can be stacked.

    Args:
        node_dim: node feature size (input and output).
        node_dim_hid: hidden width of ``node_mlp1`` and ``node_mlp2``.
        edge_dim: edge feature size (input and output).
        edge_dim_hid: hidden width of ``edge_mlp``.
    """

    def __init__(self, node_dim, node_dim_hid, edge_dim, edge_dim_hid):
        super(GNNBlock, self).__init__()
        # edge mlp: [N+N+E, ehid] x [ehid, E]
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * node_dim + edge_dim, edge_dim_hid),
            nn.ReLU(),
            nn.Linear(edge_dim_hid, edge_dim),
        )
        # node mlp1 (per-edge message): [N+E, nhid] x [nhid, N]
        self.node_mlp1 = nn.Sequential(
            nn.Linear(node_dim + edge_dim, node_dim_hid),
            nn.ReLU(),
            nn.Linear(node_dim_hid, node_dim),
        )
        # node mlp2 (node update): [N+N, nhid] x [nhid, N]
        self.node_mlp2 = nn.Sequential(
            nn.Linear(2 * node_dim, node_dim_hid),
            nn.ReLU(),
            nn.Linear(node_dim_hid, node_dim),
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the weights of every layer in the three MLPs."""
        # nn.Sequential has no reset_parameters() of its own, so reset the
        # layers it contains (ReLU has none and is skipped).
        for mlp in (self.node_mlp1, self.node_mlp2, self.edge_mlp):
            for layer in mlp:
                if hasattr(layer, 'reset_parameters'):
                    layer.reset_parameters()

    def __repr__(self):
        return ('{}(\n'
                '  edge_mlp={},\n'
                '  node_mlp1={},\n'
                '  node_mlp2={},\n'
                ')').format(self.__class__.__name__, self.edge_mlp,
                            self.node_mlp1, self.node_mlp2)

    def forward(self, x, edge_index, edge_attr):
        """Run one message-passing step.

        Args:
            x: node features, [N, F_x] with F_x == node_dim.
            edge_index: [2, E] long tensor of (source, target) node indices,
                with max entry N-1.
            edge_attr: edge features, [E, F_e] with F_e == edge_dim.

        Returns:
            Tuple ``(node_out [N, F_x], edge_index (unchanged), edge_attr [E, F_e])``.
        """
        row, col = edge_index
        # first we perform edge handling
        edge_attr = torch.cat([x[row], x[col], edge_attr], dim=-1)  # [E, 2*F_x + F_e]
        edge_attr = self.edge_mlp(edge_attr)  # [E, F_e]
        # second we perform node pass: one message per edge from its source node
        node_out = torch.cat([x[row], edge_attr], dim=-1)  # [E, F_x + F_e]
        node_out = self.node_mlp1(node_out)  # [E, F_x]
        # average the messages arriving at each target node
        node_out = scatter_mean(node_out, col, dim=0,
                                dim_size=x.size(0))  # [N, F_x]
        node_out = torch.cat([x, node_out], dim=-1)  # [N, 2*F_x]
        node_out = self.node_mlp2(node_out)  # [N, F_x]
        return node_out, edge_index, edge_attr

    def test_sizes(self):
        """Smoke-test ``forward`` on a 4-node, 6-edge graph and check shapes.

        Feature sizes are read from the layers, so this works for any block
        configuration. Weights are re-initialised afterwards.
        """
        node_dim = self.node_mlp2[-1].out_features
        edge_dim = self.edge_mlp[-1].out_features
        x = torch.randn(4, node_dim)
        edge_index = torch.tensor([[0, 1, 1, 2, 1, 3],
                                   [1, 0, 2, 1, 3, 1]], dtype=torch.long)
        edge_attr = torch.randn(6, edge_dim)
        x, edge_index, edge_attr = self.forward(x, edge_index, edge_attr)
        print(x.shape, edge_index.shape, edge_attr.shape)
        assert x.shape == torch.Size([4, node_dim])
        assert edge_index.shape == torch.Size([2, 6])
        assert edge_attr.shape == torch.Size([6, edge_dim])
        print("Test Passed")
        self.reset_parameters()
class ClassisiferNetwork(nn.Module):
    """Graph-level molecule classifier: two ``GNNBlock``s followed by a GCN layer.

    Node inputs are a one-hot element code plus extra per-atom features, and
    edge inputs are a one-hot bond code (both built by ``make_data``). After
    message passing, node embeddings are max-pooled per graph, concatenated
    with the sample's one-hot animal code and projected to class logits.
    """

    def __init__(self,
                 num_elements,
                 num_edges,
                 node_dim,
                 node_extra_feat,
                 edge_dim,
                 num_classes,
                 batch_size,
                 ):
        """
        num_elements: int, rows of the element one-hot table (must exceed the
            largest element index stored on the nodes)
        num_edges: int, rows of the bond one-hot table (must exceed the
            largest bond order after truncation to int, see ``make_data``)
        node_dim: int, hidden size of node embeddings
        node_extra_feat: int, number of node features besides the element
            one-hot; ``make_data`` produces 2 + 4 of them (charge, aromatic,
            4-slot hcount one-hot)
        edge_dim: int, hidden size of edge embeddings
        num_classes: int, number of output classes
        batch_size: int, currently unused
        """
        super(ClassisiferNetwork, self).__init__()
        # lookup tables used by make_data to build one-hot features
        self.node_onehot = np.eye(num_elements).astype(np.float32)
        self.edge_onehot = np.eye(num_edges).astype(np.float32)
        self.hcount_onehot = np.eye(4).astype(np.float32)  # hcount must be in 0..3
        self.animal_onehot = np.eye(4).astype(np.float32)  # animal id must be in 0..3
        # linear projections
        self.node_dense = nn.Sequential(
            nn.Linear(num_elements + node_extra_feat, node_dim),
            nn.ReLU()
        )
        self.edge_dense = nn.Sequential(
            nn.Linear(num_edges, edge_dim),
            nn.ReLU()
        )
        # gnn blocks + batch norms; the bn_e* norms are applied to edge
        # features, so they are sized edge_dim
        self.gnn1 = GNNBlock(node_dim, node_dim * 2, edge_dim, 2 * edge_dim)
        self.bn_n1 = tgx.nn.BatchNorm(node_dim)
        self.bn_e1 = tgx.nn.BatchNorm(edge_dim)
        self.gnn2 = GNNBlock(node_dim, node_dim * 2, edge_dim, 2 * edge_dim)
        self.bn_n2 = tgx.nn.BatchNorm(node_dim)
        self.bn_e2 = tgx.nn.BatchNorm(edge_dim)
        self.gcn = tgx.nn.GCNConv(node_dim, node_dim)
        # NOTE: bn_n3 is not used in forward
        self.bn_n3 = tgx.nn.BatchNorm(node_dim)
        # the head sees the pooled graph embedding concatenated with the
        # animal one-hot (see forward)
        self.pred_mlp = nn.Linear(node_dim + self.animal_onehot.shape[1],
                                  num_classes)

    def make_data(self, gx, y, an):
        """Convert a preprocessed molecule graph into a torch_geometric ``Data``.

        Args:
            gx: graph whose nodes carry integer ``element``, ``charge``,
                ``aromatic`` and ``hcount`` attributes (as produced by
                ``pre_process_molecule``) and whose edges carry ``order``.
                Edge endpoints are used directly as row indices into ``x``,
                so nodes are assumed to be labelled 0..N-1 in iteration
                order -- TODO confirm for every graph source.
            y: label value(s) for this sample, stored as a long tensor.
            an: integer animal id selecting a row of ``animal_onehot``.

        Returns:
            ``Data`` with ``x`` [N, num_elements + 2 + 4], ``edge_index``
            [2, 2E] (every bond stored in both directions), ``edge_attr``
            [2E, num_edges], ``an_emb`` [1, 4] and ``y``.
        """
        nodes = list(gx.nodes.data())
        # element index per node --> [N]
        node_idx = np.array([attrs["element"]
                             for _, attrs in nodes]).astype(np.int64)
        # extra per-atom features --> [N, 2 + 4]
        node_embeddings = np.array([(
            attrs["charge"],
            attrs["aromatic"],
            *self.hcount_onehot[attrs["hcount"]]
        ) for _, attrs in nodes]).astype(np.float32)
        ele_embed = self.node_onehot[node_idx]
        x = np.hstack([ele_embed, node_embeddings])

        # get edge information --> [2, 2E] + [2E, F_e]
        edge_pairs = []
        bond_orders = []
        for u, v, attrs in gx.edges.data():
            edge_pairs.append([u, v])
            edge_pairs.append([v, u])
            bond_orders.extend([attrs["order"]] * 2)  # same feature both ways
        # reshape keeps a (0, 2) array for a molecule without bonds, so
        # edge_index is [2, 0] rather than 1-D after transposing
        edge_index = np.ascontiguousarray(
            np.array(edge_pairs).astype(np.int64).reshape(-1, 2).T)
        # NOTE: bond orders are truncated to int before indexing the one-hot
        # table, so a fractional order shares the code of its integer part
        bond_codes = np.array(bond_orders).astype(np.int64)
        edge_attr = self.edge_onehot[bond_codes]
        return Data(
            x=torch.from_numpy(x),
            edge_index=torch.from_numpy(edge_index),
            edge_attr=torch.from_numpy(edge_attr),
            an_emb=torch.from_numpy(np.expand_dims(
                self.animal_onehot[an], axis=0)),
            y=torch.tensor(y, dtype=torch.long)
        )

    def forward(self, gx):
        """Return class logits for a batch of graphs.

        Args:
            gx: batched torch_geometric data providing ``x``, ``edge_index``,
                ``edge_attr``, ``batch`` and ``an_emb`` (as built by
                ``make_data`` and collated by a torch_geometric DataLoader).

        Returns:
            Logits of shape [num_graphs, num_classes].
        """
        x, edge_index, edge_attr = gx.x, gx.edge_index, gx.edge_attr
        # dense
        node_out = self.node_dense(x)
        edge_out = self.edge_dense(edge_attr)
        # gnn_blocks + batch norms
        x, edge_index, edge_attr = self.gnn1(node_out, edge_index, edge_out)
        x = self.bn_n1(x)
        edge_attr = self.bn_e1(edge_attr)
        x, edge_index, edge_attr = self.gnn2(x, edge_index, edge_attr)
        x = self.bn_n2(x)
        # edge features are not consumed after this point (GCNConv is called
        # with connectivity only)
        edge_attr = self.bn_e2(edge_attr)
        node_out = self.gcn(x, edge_index)
        max_pooled_graph = tgx.nn.global_max_pool(node_out, gx.batch)
        # an_emb is [1, 4] per graph, so the collated batch is [num_graphs, 4]
        graph_emb = torch.cat([
            max_pooled_graph,
            gx.an_emb
        ], dim=-1)
        # final classification layer; including the animal code lets the same
        # molecule receive different predictions for different animals
        carcinogenic_pred = self.pred_mlp(graph_emb)
        return carcinogenic_pred
# helper functions
def calculate_accuracy(target, pred):
    """Return the fraction of samples whose arg-max prediction equals the target.

    Args:
        target: integer class labels with one element per sample
            (e.g. ``[B]`` or ``[B, 1]``).
        pred: class scores with the class axis last (e.g. ``[B, C]``).

    Returns:
        0-dim float tensor holding the accuracy in [0, 1].
    """
    # softmax is monotonic, so the arg-max of the raw scores is used directly.
    # Flattening (instead of squeezing) keeps the batch axis when B == 1.
    pred_labels = torch.argmax(pred, dim=-1).reshape(-1)
    check = target.reshape(-1) == pred_labels
    return check.float().mean()
def pre_process_molecule(gx, ele2idx=None):
    """Convert a molecule graph's node attributes to integer codes, in place.

    For every node, ``element`` is replaced by its index in the element map,
    and ``aromatic`` and ``hcount`` are cast to ``int``. Other attributes are
    left untouched. Not idempotent: with a symbol-keyed map, a second call
    would look up an already-converted integer and raise ``KeyError``.

    Args:
        gx: graph whose ``nodes.data()`` yields ``(node, attr_dict)`` pairs.
        ele2idx: mapping from element symbol to index. Defaults to the
            module-level ``elements2idx`` defined by the script entry point.

    Raises:
        KeyError: if a node's element is missing from the map.
    """
    if ele2idx is None:
        ele2idx = elements2idx
    for _, attrs in gx.nodes.data():
        attrs.update({
            "element": ele2idx[attrs["element"]],
            "aromatic": int(attrs["aromatic"]),
            "hcount": int(attrs["hcount"]),
        })
if __name__ == "__main__":
    # load the CSV
    fdata = pd.read_csv("PTC_pn/merged.csv")

    # load graphs
    print("Making all graphs ...")
    graphs_all = [read_smiles(smiles) for smiles in fdata["smiles"]]
    print("... Done!")

    print("Loading Node and Edge Attribute data ...")
    # [NODE] load elements
    elements = list({attrs["element"]
                     for gx in graphs_all for _, attrs in gx.nodes.data()})
    NUM_ELEMENTS = len(elements)
    print("Elements: ", elements, NUM_ELEMENTS)
    # [NODE] number of attached hydrogen Atoms
    num_hcounts = list({attrs["hcount"]
                        for gx in graphs_all for _, attrs in gx.nodes.data()})
    NUM_HCOUNTS = len(num_hcounts)
    print("Number of attached Hydrogen atoms:", num_hcounts, NUM_HCOUNTS)
    # [NODE] whether Aromatic or not
    aromatic = list({attrs["aromatic"]
                     for gx in graphs_all for _, attrs in gx.nodes.data()})
    print("aromatic:", aromatic)
    # [NODE] what is the charge on the atom
    charge = list({attrs["charge"]
                   for gx in graphs_all for _, attrs in gx.nodes.data()})
    print("charge:", charge)
    # [EDGE] information about the bond
    bonds = list({attrs["order"]
                  for gx in graphs_all for _, _, attrs in gx.edges.data()})
    NUM_BONDS = len(bonds)
    print("information about the bond:", bonds, NUM_BONDS)
    # make_data indexes the bond one-hot table with int(order), so the table
    # needs max(int(order)) + 1 rows, which need not equal NUM_BONDS
    # (e.g. orders {1, 2, 3} need 4 rows).
    bond_onehot_size = int(max(bonds)) + 1 if bonds else 1

    # load element2id (could also be read from "PTC_pn/ele2id.json");
    # pre_process_molecule reads this module-level name
    elements2idx = {"Zn": 0, "Ba": 1, "Sn": 2, "C": 3, "Cl": 4, "K": 5, "S": 6, "N": 7, "Cu": 8, "Na": 9, "F": 10, "Ca": 11, "O": 12, "In": 13, "P": 14, "Pb": 15, "B": 16, "As": 17, "Te": 18, "I": 19, "Br": 20}
    print("... Done!")

    # preprocess the graphs (element symbols -> indices, in place)
    print("Preprocessing graphs ...")
    for gx in graphs_all:
        pre_process_molecule(gx)
    print("... Done!")

    cl = ClassisiferNetwork(
        # the element one-hot is indexed with elements2idx values, so size it
        # from the map rather than from the elements seen in the CSV
        num_elements=max(elements2idx.values()) + 1,
        num_edges=bond_onehot_size,
        node_dim=32,
        node_extra_feat=2 + 4,
        edge_dim=32,
        num_classes=2,
        batch_size=20
    )

    # create proper dataset mate!
    # One sample per (molecule, label column) with a non-zero label; the
    # column's position in this list is the animal id given to make_data.
    target_columns = ["carcinogenic_FM", "carcinogenic_MM",
                      "carcinogenic_FR", "carcinogenic_MR"]
    dataset = []
    for idx, gx in enumerate(graphs_all):
        if idx == 214:
            # issue with graph
            continue
        for animal_id, column in enumerate(target_columns):
            target = int(fdata[column][idx])
            if target == 0:
                continue  # zero entries are not used as samples
            # max(target, 0) folds negative labels into class 0
            dataset.append(cl.make_data(gx, [max(target, 0)], an=animal_id))
    print("Number of learning samples:", len(dataset))
    pytorch_total_params = sum(p.numel() for p in cl.parameters())
    print("Number of network paramters", pytorch_total_params)

    # training
    num_epochs = 2
    batch_size = 128
    lr = 0.001
    loss_fn = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(cl.parameters(), lr=lr)
    # shuffle so batches are not ordered by molecule / CSV row
    loader = tgx.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    losses = []
    acc = []
    for _epoch in trange(num_epochs):
        for batch in loader:
            optim.zero_grad()  # optim owns every parameter of cl
            net_pred = cl(batch)
            loss_total = loss_fn(net_pred, batch.y)
            loss_total.backward()
            optim.step()
            # record plain floats rather than tensors that still reference
            # the autograd graph
            losses.append(loss_total.item())
            acc.append(calculate_accuracy(batch.y, net_pred).item())

    plt.figure(figsize=(10, 6))
    plt.title("Losses (batch_size = {}; lr = {})".format(batch_size, lr))
    plt.plot(losses, label="Loss")
    plt.plot(acc, label="Accuracy")
    plt.xlabel("Number of steps")
    plt.legend()
    plt.savefig("status.png")
@yashbonde
Copy link
Author

For more information about this gist, visit my blog, where I explain how to build a PTC drug classifier.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment