@theabc50111 · Created February 14, 2023 07:26
A composite model that combines plain PyTorch layers with torch_geometric (PyG) modules.
import pickle
from itertools import islice

import torch
import torch.nn.functional as F
from torch.nn import Linear, GRU, Sequential, BatchNorm1d, ReLU, Dropout
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader  # DataLoader moved out of torch_geometric.data in PyG 2.x
from torch_geometric.nn import GCNConv, GINConv, global_mean_pool, global_add_pool, summary

# The pickle file can be downloaded here:
# https://drive.google.com/drive/folders/1_KMwCzf1diwS4gGNdSSxG7bnemqQkFxI?usp=sharing
with open('tmp_torch_graph_dataset.pickle', 'rb') as handle:
    train_dataset = pickle.load(handle)

train_loader = DataLoader(train_dataset, batch_size=12, shuffle=False)
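# --- Hypothetical sanity check (not part of the original gist) ---
# Peek at the first mini-batch: with batch_size=12, `data.batch` maps every
# node to one of the 12 graphs merged into the batch.
_sample = next(iter(train_loader))
print(_sample)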
class GinEncoder(torch.nn.Module):
    """Two-layer GIN encoder that maps a (batched) graph to an 8-dim embedding."""

    def __init__(self):
        super(GinEncoder, self).__init__()
        self.gin_convs = torch.nn.ModuleList()
        self.gin_convs.append(GINConv(Sequential(Linear(1, 4),
                                                 BatchNorm1d(4), ReLU(),
                                                 Linear(4, 4), ReLU())))
        self.gin_convs.append(GINConv(Sequential(Linear(4, 4),
                                                 BatchNorm1d(4), ReLU(),
                                                 Linear(4, 4), ReLU())))

    def forward(self, x, edge_index, batch_node_id):
        # Node embeddings from each GIN layer
        nodes_emb_layers = []
        for i in range(2):
            x = self.gin_convs[i](x, edge_index)
            nodes_emb_layers.append(x)
        # Graph-level readout of every layer's node embeddings
        nodes_emb_pools = [global_add_pool(nodes_emb, batch_node_id)
                           for nodes_emb in nodes_emb_layers]
        # Concatenate the per-layer readouts to form the graph embeddings
        graph_embeds = torch.cat(nodes_emb_pools, dim=1)
        return graph_embeds

    def get_embeddings(self, x, edge_index, batch_node_id):
        # Gradient-free embeddings, flattened to a 1-D vector (used as the target below)
        with torch.no_grad():
            graph_embeds = self.forward(x, edge_index, batch_node_id).reshape(-1)
        return graph_embeds
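# --- Hypothetical smoke test (not part of the original gist) ---
# Two toy graphs of 3 nodes each, with 1 feature per node, to confirm that
# the encoder emits one 8-dim embedding per graph (2 layers x 4 channels).
_x = torch.rand(6, 1)
_edge_index = torch.tensor([[0, 1, 3, 4],
                            [1, 2, 4, 5]])
_batch = torch.tensor([0, 0, 0, 1, 1, 1])
print(GinEncoder()(_x, _edge_index, _batch).shape)  # torch.Size([2, 8])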
class MainModel(torch.nn.Module):
    """Composite model: a PyG graph encoder followed by plain PyTorch linear layers."""

    def __init__(self, graph_encoder: torch.nn.Module):
        super(MainModel, self).__init__()
        self.graph_encoder = graph_encoder
        self.lin1 = Linear(8, 4)
        self.lin2 = Linear(4, 8)

    def forward(self, x, edge_index, batch_node_id):
        graph_embeds = self.graph_encoder(x, edge_index, batch_node_id)
        out_lin1 = self.lin1(graph_embeds)
        # Keep only the last graph's output as the predicted embedding
        pred = self.lin2(out_lin1)[-1]
        return pred
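# --- Hypothetical smoke test (not part of the original gist), reusing the toy
# tensors above: the head maps (num_graphs, 8) -> (num_graphs, 4) ->
# (num_graphs, 8), and `[-1]` keeps the last graph's row, so the prediction
# is a single 8-dim graph embedding.
print(MainModel(GinEncoder())(_x, _edge_index, _batch).shape)  # torch.Size([8])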
gin_encoder = GinEncoder().to("cuda")
model = MainModel(gin_encoder).to("cuda")
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())
epochs = 10
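# --- Hypothetical portability note (not part of the original gist) ---
# The gist hard-codes "cuda"; a common fallback when no GPU is available:
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = MainModel(GinEncoder()).to(device)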
for epoch_i in range(epochs):
    model.train()
    train_loss = 0
    for batch_i, data in enumerate(train_loader):
        data = data.to("cuda")
        # Input: the batched graphs; target: the last graph stored in `data.y`
        x, x_edge_index, x_batch_node_id = data.x, data.edge_index, data.batch
        y, y_edge_index = data.y[-1].x, data.y[-1].edge_index
        y_batch_node_id = torch.zeros(data.y[-1].x.shape[0], dtype=torch.int64).to("cuda")
        optimizer.zero_grad()
        graph_embeds_pred = model(x, x_edge_index, x_batch_node_id)
        # The target embedding is produced without gradients by the same encoder
        y_graph_embeds = model.graph_encoder.get_embeddings(y, y_edge_index, y_batch_node_id)
        loss = criterion(graph_embeds_pred, y_graph_embeds)
        train_loss += loss.item()  # .item() detaches; accumulating the tensor would keep every graph alive
        loss.backward()
        optimizer.step()
        if batch_i == 0:
            print(f"NO. {epoch_i} EPOCH")
            print(f"MainModel weights in epoch_{epoch_i}_batch0: {next(islice(model.parameters(), 15, 16))}", end="\n\n")
            print(f"GinEncoder weights in epoch_{epoch_i}_batch0: {next(model.graph_encoder.parameters())}")
            print("*" * 80)