Created
February 14, 2023 07:26
-
-
Save theabc50111/3ca708d0c1101d57b6172bd717302710 to your computer and use it in GitHub Desktop.
A composite model built with PyTorch and PyTorch Geometric (torch-geometric).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pickle
from itertools import islice

import torch
import torch.nn.functional as F
import torch_geometric
from torch.nn import Linear, GRU, Sequential, BatchNorm1d, ReLU, Dropout
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GINConv, global_mean_pool, global_add_pool, summary
# The pickled dataset can be downloaded from:
# https://drive.google.com/drive/folders/1_KMwCzf1diwS4gGNdSSxG7bnemqQkFxI?usp=sharing
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file,
# so only load a copy of this file that comes from a source you trust.
with open('tmp_torch_graph_dataset.pickle', 'rb') as handle:
    train_dataset = pickle.load(handle)

# shuffle=False keeps the batches in the order the dataset stores its graphs.
train_loader = DataLoader(train_dataset, batch_size=12, shuffle=False)
class GinEncoder(torch.nn.Module):
    """Encode graphs by stacking GINConv layers and sum-pooling every layer.

    Each layer wraps the MLP ``Linear -> BatchNorm1d -> ReLU -> Linear -> ReLU``
    in a ``GINConv``. The node embeddings produced by every layer are pooled
    per graph with ``global_add_pool`` and the pooled vectors are concatenated,
    so one graph is represented by ``num_layers * hidden_dim`` values.

    Args:
        in_dim: Size of the input node features (default ``1``).
        hidden_dim: Output size of every GIN layer (default ``4``).
        num_layers: Number of stacked GIN layers (default ``2``).

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_dim: int = 1, hidden_dim: int = 4, num_layers: int = 2):
        super().__init__()
        if num_layers < 1:
            raise ValueError("num_layers must be at least 1")
        self.gin_convs = torch.nn.ModuleList()
        for layer_i in range(num_layers):
            # Only the first layer consumes the raw node features; later
            # layers consume the previous layer's embeddings.
            layer_in_dim = in_dim if layer_i == 0 else hidden_dim
            self.gin_convs.append(
                GINConv(
                    Sequential(
                        Linear(layer_in_dim, hidden_dim),
                        BatchNorm1d(hidden_dim),
                        ReLU(),
                        Linear(hidden_dim, hidden_dim),
                        ReLU(),
                    )
                )
            )

    def forward(self, x, edge_index, batch_node_id):
        """Return one embedding row per graph in the batch.

        Args:
            x: Node features; the last dimension must equal ``in_dim``.
            edge_index: Edge list passed unchanged to every ``GINConv``.
            batch_node_id: For every node, the index of the graph it belongs
                to (used by ``global_add_pool`` to group nodes).

        Returns:
            The per-layer pooled embeddings concatenated along ``dim=1``.
        """
        pooled_per_layer = []
        # Iterate the layers themselves so the loop always matches however
        # many layers were registered in __init__.
        for gin_conv in self.gin_convs:
            x = gin_conv(x, edge_index)
            # Graph-level readout of this layer's node embeddings.
            pooled_per_layer.append(global_add_pool(x, batch_node_id))
        # Concatenate the readouts of all layers into the graph embeddings.
        return torch.cat(pooled_per_layer, dim=1)

    def get_embeddings(self, x, edge_index, batch_node_id):
        """Return the graph embeddings flattened to 1-D, without gradients.

        NOTE(review): ``torch.no_grad()`` only disables gradient tracking. If
        the module is in training mode, the BatchNorm1d layers still update
        their running statistics during this call.
        """
        with torch.no_grad():
            graph_embeds = self.forward(x, edge_index, batch_node_id).reshape(-1)
        return graph_embeds
class MainModel(torch.nn.Module):
    """Predict a graph embedding from a batch of encoded graphs.

    The wrapped ``graph_encoder`` turns the batch into one embedding row per
    graph; two linear layers then map ``embed_dim -> hidden_dim -> embed_dim``.

    Args:
        graph_encoder: Module called as
            ``graph_encoder(x, edge_index, batch_node_id)`` that returns a
            tensor whose last dimension is ``embed_dim``.
        embed_dim: Size of the encoder output and of the prediction
            (default ``8``).
        hidden_dim: Size of the intermediate linear layer (default ``4``).
    """

    def __init__(self, graph_encoder: torch.nn.Module, embed_dim: int = 8, hidden_dim: int = 4):
        super().__init__()
        self.graph_encoder = graph_encoder
        self.lin1 = Linear(embed_dim, hidden_dim)
        self.lin2 = Linear(hidden_dim, embed_dim)

    def forward(self, x, edge_index, batch_node_id):
        """Return the prediction computed for the last graph of the batch.

        Args:
            x: Node features forwarded to the encoder.
            edge_index: Edge list forwarded to the encoder.
            batch_node_id: Node-to-graph assignment forwarded to the encoder.

        Returns:
            The last row of ``lin2(lin1(graph_embeds))``.
        """
        graph_embeds = self.graph_encoder(x, edge_index, batch_node_id)
        hidden = self.lin1(graph_embeds)
        # Every graph gets an output row, but only the last row is returned.
        pred = self.lin2(hidden)[-1]
        return pred
# Train on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

gin_encoder = GinEncoder().to(device)
model = MainModel(gin_encoder).to(device)
criterion = torch.nn.MSELoss()
# model.parameters() includes the encoder's parameters, because the encoder is
# registered as a sub-module of MainModel.
optimizer = torch.optim.Adam(model.parameters())
epochs = 10

for epoch_i in range(epochs):
    model.train()
    # Accumulate plain floats: summing the loss tensors themselves would keep
    # every batch's autograd graph referenced for the whole epoch.
    train_loss = 0.0
    for batch_i, data in enumerate(train_loader):
        data = data.to(device)
        x, x_edge_index, x_batch_node_id = data.x, data.edge_index, data.batch

        # The target is built from the last graph stored in data.y. Its nodes
        # all get graph id 0, so they are pooled as one single graph.
        target_graph = data.y[-1]
        y, y_edge_index = target_graph.x, target_graph.edge_index
        y_batch_node_id = torch.zeros(y.shape[0], dtype=torch.int64, device=device)

        optimizer.zero_grad()
        graph_embeds_pred = model(x, x_edge_index, x_batch_node_id)
        # The target embedding comes from the same encoder, without gradients.
        y_graph_embeds = model.graph_encoder.get_embeddings(y, y_edge_index, y_batch_node_id)
        loss = criterion(graph_embeds_pred, y_graph_embeds)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()

        if batch_i == 0:
            # Print one MainModel parameter and the first encoder parameter
            # after the first optimizer step of every epoch.
            # NOTE(review): index 15 is presumably lin2.bias (the last
            # registered parameter) - confirm against model.named_parameters().
            print(f"NO. {epoch_i} EPOCH")
            print(f"MainModel weights in epoch_{epoch_i}_batch0:{next(islice(model.parameters(), 15, 16))}", end="\n\n")
            print(f"GinEncoder weights in epoch_{epoch_i}_batch0:{next(model.graph_encoder.parameters())}")
            print("*"*80)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment