Skip to content

Instantly share code, notes, and snippets.

@cocomoff
Created October 25, 2023 14:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cocomoff/144b295a8ff2dd57482a9087cbc476d2 to your computer and use it in GitHub Desktop.
Save cocomoff/144b295a8ff2dd57482a9087cbc476d2 to your computer and use it in GitHub Desktop.
Smallest running example of GCN-like processing using PyTorch Geometric
import networkx as nx
import torch
import torch_geometric
from torch.nn import Linear, Parameter
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, from_networkx
def get_graph() -> nx.Graph:
    """Build the small undirected demo graph.

    Nodes 0..5 are inserted first (so iteration order is 0, 1, ..., 5),
    followed by the undirected edges 0-1, 0-2, 1-2, 2-3, 2-4 and 4-5.
    """
    num_nodes = 6
    edges = [(0, 1), (0, 2), (1, 2), (2, 3), (2, 4), (4, 5)]

    graph = nx.Graph()
    graph.add_nodes_from(range(num_nodes))
    graph.add_edges_from(edges)
    return graph
def get_fixed_feature() -> torch.FloatTensor:
    """Return the fixed node-feature matrix used by the demo.

    Row ``k`` is the 3-dimensional feature vector of node ``k`` (6 nodes in
    total). The values are small 0/1 patterns so the GNN output can be
    checked by hand. The tensor is float32.
    """
    rows = [
        [0, 0, 0],  # node 0
        [0, 0, 1],  # node 1
        [0, 1, 0],  # node 2
        [0, 1, 1],  # node 3
        [1, 0, 0],  # node 4
        [1, 1, 0],  # node 5
    ]
    return torch.tensor(rows, dtype=torch.float32)
class SimpleGNN(MessagePassing):
    """Minimal GCN-like layer: out_i = W x_i + sum over neighbours j of W x_j.

    Self-loops are added before message passing and messages are summed
    (``aggr="add"``). No degree normalization, bias or nonlinearity is
    applied, so the output can be reproduced by hand.
    """

    def __init__(self, in_channels, out_channels):
        """Create the layer.

        Args:
            in_channels: dimension d of each input node feature.
            out_channels: dimension d' of each output node feature.
        """
        # Messages from neighbours are combined by plain summation.
        super().__init__(aggr="add")
        self.lin = Linear(in_channels, out_channels, bias=False)
        # Deterministic weights W[r, c] = r * in_channels + c + 1 so the demo
        # is reproducible. For (in_channels, out_channels) == (3, 2) this is
        # exactly [[1, 2, 3], [4, 5, 6]]. Unlike a hard-coded 2x3 matrix, it
        # also has the requested (out_channels, in_channels) shape for other
        # sizes.
        init = torch.arange(
            1, in_channels * out_channels + 1, dtype=self.lin.weight.dtype
        ).view(out_channels, in_channels)
        with torch.no_grad():
            self.lin.weight.copy_(init)

    def forward(self, x, edge_index):
        """Apply the layer.

        Args:
            x: node-feature matrix whose row k belongs to node k; ``x.size(0)``
                is used as the number of nodes (presumably shape
                (num_nodes, in_channels) -- confirm with callers).
            edge_index: (src, dst) edge list, e.g. as produced by PyG's
                ``from_networkx``.

        Returns:
            Tensor whose row i is the sum of W x_j over all edges j -> i,
            including the added self-loop i -> i.
        """
        # Add i -> i so each node's own transformed feature is included.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        # Transform every node's feature vector once: x <- W x.
        x = self.lin(x)
        # propagate() calls message() per edge and sums the results per node.
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        """Pass the source node's transformed feature through unchanged."""
        return x_j
def _manual_aggregate(lin, X, G, i):
    """Hand-compute lin(X[i]) + sum of lin(X[j]) over the neighbours j of i.

    ``lin`` maps one feature row to its transform (here a bias-free Linear
    layer). ``G`` must provide ``neighbors(i)``. This mirrors SimpleGNN with
    self-loops and sum aggregation.
    """
    total = lin(X[i, :])  # node i itself (the self-loop term)
    for j in G.neighbors(i):
        total = total + lin(X[j, :])  # each neighbour j contributes j -> i
    return total


if __name__ == "__main__":
    G = get_graph()
    X = get_fixed_feature()
    print(G)
    print(X)

    # Build the PyG data object from the networkx graph.
    data = from_networkx(G)
    # edge_index holds the edge list as (src, dst) pairs.
    print(data.edge_index)

    # Example 1: a plain bias-free Linear layer with fixed weights.
    W = Linear(3, 2, bias=False)
    W.weight = Parameter(torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32))

    # Example 2: the message-passing layer (same weights for (3, 2)).
    gnn = SimpleGNN(3, 2)

    # Inference only: no_grad avoids building autograd graphs, so the
    # tensors can be printed directly instead of via the deprecated `.data`.
    with torch.no_grad():
        for i in G.nodes():
            sum_i = _manual_aggregate(W, X, G, i)
            print("ノード:", i, " 手計算:", sum_i)

        out = gnn(X, data.edge_index)
        for i in G.nodes():
            # Aggregate messages j -> i by hand and compare with the GNN.
            sum_i = _manual_aggregate(gnn.lin, X, G, i)
            print("ノード:", i, " 手計算:", sum_i, " GNN計算:", out[i])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment