Created
October 25, 2023 14:10
-
-
Save cocomoff/144b295a8ff2dd57482a9087cbc476d2 to your computer and use it in GitHub Desktop.
Smallest running example of GCN-like processing using PyTorch Geometric
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import networkx as nx | |
import torch | |
import torch_geometric | |
from torch.nn import Linear, Parameter | |
from torch_geometric.data import Data | |
from torch_geometric.nn import MessagePassing | |
from torch_geometric.utils import add_self_loops, degree, from_networkx | |
def get_graph() -> nx.Graph:
    """Build the small undirected example graph used throughout the demo.

    The graph has six nodes labelled 0..5 and six undirected edges.

    Returns:
        A ``networkx.Graph`` containing the nodes and edges listed below.
    """
    node_count = 6
    edges = ((0, 1), (0, 2), (1, 2), (2, 3), (2, 4), (4, 5))

    graph = nx.Graph()
    # Register the nodes first so they are inserted in order 0..5.
    graph.add_nodes_from(range(node_count))
    graph.add_edges_from(edges)
    return graph
def get_fixed_feature() -> torch.FloatTensor:
    """Return the fixed node-feature matrix for the example graph.

    Row ``i`` holds the 3-dimensional feature vector of node ``i``.

    Returns:
        A float32 tensor of shape (6, 3).
    """
    rows = [
        (0, 0, 0),  # node 0
        (0, 0, 1),  # node 1
        (0, 1, 0),  # node 2
        (0, 1, 1),  # node 3
        (1, 0, 0),  # node 4
        (1, 1, 0),  # node 5
    ]
    return torch.tensor(rows, dtype=torch.float32)
class SimpleGNN(MessagePassing):
    """Minimal GCN-like layer: linear transform followed by sum aggregation.

    Each node's output is the sum of the linearly transformed features of
    itself (via an added self-loop) and of its neighbours. No degree
    normalisation is applied.
    """

    def __init__(self, in_channels, out_channels):
        """Create the layer with a deterministic weight matrix.

        Args:
            in_channels: Dimension of the input node features (d).
            out_channels: Dimension of the output node features (d').
        """
        # Messages are combined by plain summation ("add").
        super().__init__(aggr="add")

        self.lin = Linear(in_channels, out_channels, bias=False)
        # Fixed initialisation so the result can be reproduced by hand.
        # The weight is filled with 1, 2, 3, ... in row-major order and shaped
        # (out_channels, in_channels); for in_channels=3, out_channels=2 this
        # is exactly [[1, 2, 3], [4, 5, 6]].
        fixed = torch.arange(
            1, in_channels * out_channels + 1, dtype=torch.float32
        )
        self.lin.weight = Parameter(fixed.reshape(out_channels, in_channels))

    def forward(self, x, edge_index):
        """Apply the layer.

        Args:
            x: Node feature matrix; ``x.size(0)`` is used as the node count.
            edge_index: Edge list in the (src, dst) representation carried by
                a PyG ``Data`` object.

        Returns:
            The result of ``propagate``: per-node sums of transformed features.
        """
        # Add self-loops so each node also receives its own message.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Compute Wx: every node's feature vector is transformed by W.
        x = self.lin(x)

        # Triggers message(); ``norm`` is passed only to satisfy its signature.
        out = self.propagate(edge_index, x=x, norm=None)
        return out

    def message(self, x_j, norm):
        """Return the source-node features unchanged.

        ``norm`` is accepted but unused; aggregation is left to "add".
        """
        return x_j
if __name__ == "__main__":
    # Demo script: compare a hand-rolled neighbourhood sum with SimpleGNN.
    G = get_graph()
    X = get_fixed_feature()
    print(G)
    print(X)

    # Convert the networkx graph into a PyG data object.
    data = from_networkx(G)

    # edge_index stores the edge information.
    print(data.edge_index)

    # Everything below is inference only, so autograd tracking is disabled.
    # This also lets the tensors be printed directly instead of going through
    # the discouraged ``.data`` attribute.
    with torch.no_grad():
        # Example 1: manual computation with a stand-alone linear layer.
        W = Linear(3, 2, bias=False)
        W.weight = Parameter(
            torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
        )
        for i in G.nodes():
            sum_i = W(X[i, :])  # the node itself (i)
            for j in G.neighbors(i):
                sum_i += W(X[j, :])  # adjacent node j
            print("ノード:", i, " 手計算:", sum_i)

        # Example 2: the same sum computed by the message-passing layer.
        gnn = SimpleGNN(3, 2)
        out = gnn(X, data.edge_index)
        for i in G.nodes():
            # Messages flow j -> i.
            sum_i = gnn.lin(X[i, :])
            for j in G.neighbors(i):
                sum_i += gnn.lin(X[j, :])
            print("ノード:", i, " 手計算:", sum_i, " GNN計算:", out[i])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment