Created
April 23, 2022 08:42
-
-
Save queensferryme/1526acfd04a29b86fbeeaf5c09db17e1 to your computer and use it in GitHub Desktop.
A GraphSAGE implementation with Julia & Flux.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "1e2afde0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"using Flux\n", | |
"using LinearAlgebra: norm\n", | |
"using MLDatasets: Cora\n", | |
"using StatsBase: mean, sample\n", | |
"\n", | |
"import Zygote" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "21a1dfcf", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(node_features = Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], node_labels = [4, 5, 5, 1, 4, 3, 1, 4, 4, 3 … 4, 4, 4, 4, 1, 4, 4, 4, 4, 4], adjacency_list = [[634, 1863, 2583], [3, 653, 655], [1987, 333, 1667, 2, 1455], [2545], [2177, 1017, 1762, 1257, 2176], [1630, 2547, 1660], [1417, 1603, 1043, 374], [209], [282, 1997, 270], [2615, 724] … [1401, 1574], [2631], [1152], [45, 2625], [187, 1537], [1299], [642], [288], [166, 2708, 1474, 170], [599, 166, 1474, 2707]], train_indices = 1:140, val_indices = 141:640, test_indices = 1709:2708, num_classes = 7, num_nodes = 2708, num_edges = 10556, directed = false)" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"data = Cora.dataset()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "842a981c", | |
"metadata": {}, | |
"source": [ | |
"# Linear Regression\n", | |
"\n", | |
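"Baseline: just ignore the graph topology :( and fit a linear softmax classifier (two `Dense` layers with no activation in between, trained with cross-entropy) on the raw node features." | |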
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "02c6dbbb", | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"7×1000 OneHotMatrix(::Vector{UInt32}) with eltype Bool:\n", | |
" ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ … ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅\n", | |
" ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅\n", | |
" ⋅ 1 1 1 1 ⋅ 1 1 1 1 1 1 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅\n", | |
" 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 1 1 1 1 1 ⋅ 1 1 1 1 1\n", | |
" ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅\n", | |
" ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ … ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅\n", | |
" ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x = data.node_features\n", | |
"y = Flux.onehotbatch(data.node_labels, 1:7)\n", | |
"\n", | |
"train_x = x[:, data.train_indices]\n", | |
"train_y = y[:, data.train_indices]\n", | |
"\n", | |
"val_x = x[:, data.val_indices]\n", | |
"val_y = y[:, data.val_indices]\n", | |
"\n", | |
"test_x = x[:, data.test_indices]\n", | |
"test_y = y[:, data.test_indices]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "4f8d91c9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Descent(0.1)" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"acc(x, y) = mean(Flux.onecold(model(x)) .== Flux.onecold(y))\n", | |
"loss(x, y) = Flux.crossentropy(model(x), y)\n", | |
"opt = Descent()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "04e64719", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(::Flux.var\"#on_trigger#134\"{Flux.var\"#on_trigger#133#135\"{Flux.var\"#137#139\"{Flux.var\"#137#138#140\"{typeof(-), Int64, var\"#1#2\"}}, Int64}}) (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model = Chain(\n", | |
" Dense(1433, 256),\n", | |
" Dense(256, 7),\n", | |
" softmax,\n", | |
")\n", | |
"\n", | |
"es = let f = () -> loss(val_x, val_y)\n", | |
" Flux.early_stopping(f, 3; init_score=f())\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "125b5a14", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"acc(train_x, train_y) = 1.0\n", | |
"loss(train_x, train_y) = 0.069662094f0\n", | |
"acc(test_x, test_y) = 0.573\n", | |
"loss(test_x, test_y) = 1.2368718f0\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"┌ Info: Epoch 1\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 2\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 3\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 4\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 5\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 6\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 7\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 8\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 9\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 10\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 11\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 12\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 13\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 14\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 15\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 16\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 17\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 18\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 19\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 20\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 21\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 22\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 23\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 24\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 25\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 26\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 27\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 28\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 29\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 30\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 31\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 32\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"1.2368718f0" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dataset = Flux.DataLoader((train_x, train_y), batchsize=20, shuffle=true)\n", | |
"\n", | |
"Flux.@epochs 10000 begin\n", | |
" Flux.train!(loss, Flux.params(model), dataset, opt)\n", | |
" es() && break\n", | |
"end\n", | |
"\n", | |
"@show acc(train_x, train_y)\n", | |
"@show loss(train_x, train_y)\n", | |
"@show acc(test_x, test_y)\n", | |
"@show loss(test_x, test_y)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "114bd83f", | |
"metadata": {}, | |
"source": [ | |
"# GraphSAGE\n", | |
"\n", | |
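"We first smooth the node features with `K = 2` rounds of GraphSAGE-style mean aggregation over each node and its neighbors (followed by `relu`; no trainable weights are involved at this stage), then train the same linear softmax classifier as above on the aggregated features." | |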
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "7bcc168b", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"K = 2\n", | |
"h = copy(x)\n", | |
"\n", | |
"for _ in 1:K\n", | |
" h_next = similar(h)\n", | |
" for i in 1:data.num_nodes\n", | |
" h_next[:, i] = begin\n", | |
" h[:, cat(data.adjacency_list[i], [i], dims=1)] |>\n", | |
" x -> mean(x, dims=2) .|>\n", | |
" relu\n", | |
" end\n", | |
" end\n", | |
" h = h_next\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "28c80808", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"1433×1000 Matrix{Float32}:\n", | |
" 0.000845309 0.000986193 0.000986193 … 0.0 0.0 0.0 0.0 0.00588235\n", | |
" 0.00169062 0.00449764 0.00197239 0.0 0.0 0.0 0.0 0.00588235\n", | |
" 0.0705631 0.0105343 0.00295858 0.0 0.5 0.0 0.0 0.00588235\n", | |
" 0.00338123 0.00394477 0.00394477 0.0 0.0 0.0 0.0 0.0\n", | |
" 0.000845309 0.000986193 0.000986193 0.0 0.0 0.0 0.0 0.00588235\n", | |
" 0.0 0.0 0.0 … 0.5 0.0 0.0 0.0 0.00588235\n", | |
" 0.00591716 0.0260953 0.0381534 0.0 0.0 0.0 0.0 0.0\n", | |
" 0.208328 0.000986193 0.000986193 0.0 0.0 0.0 0.0 0.00588235\n", | |
" 0.00253593 0.00295858 0.00295858 0.0 0.0 0.0 0.0 0.0\n", | |
" 0.000845309 0.000986193 0.000986193 0.0 0.0 0.0 0.0 0.0\n", | |
" 0.0 0.0166667 0.0 … 0.0 0.0 0.0 0.0 0.0\n", | |
" 0.0 0.0 0.0333333 0.0 0.0 0.0 0.0 0.0\n", | |
" 0.000845309 0.000986193 0.0114029 0.0 0.0 0.0 0.0 0.0\n", | |
" ⋮ ⋱ ⋮ \n", | |
" 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", | |
" 0.00507185 0.00844241 0.00591716 0.0 0.0 0.0 0.0 0.0\n", | |
" 0.0918166 0.00647003 0.0610876 0.0 0.0 0.0 0.0 0.00588235\n", | |
" 0.00422654 0.00745622 0.00493097 0.0 0.5 0.0 0.0 0.0\n", | |
" 0.0 0.00252525 0.0238095 … 0.5 0.0 0.0 0.28 0.291765\n", | |
" 0.0968885 0.0452155 0.0336715 0.0 0.0 0.0 0.0 0.0\n", | |
" 0.0229441 0.00295858 0.00295858 0.0 0.0 0.0 0.0 0.0\n", | |
" 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", | |
" 0.00169062 0.102604 0.00197239 0.0 0.0 0.0 0.0 0.0\n", | |
" 0.0 0.0 0.0 … 0.0 0.0 0.0 0.0 0.0\n", | |
" 0.0688725 0.000986193 0.000986193 0.0 0.0 0.0 0.0 0.0\n", | |
" 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.00588235" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_h = h[:, data.train_indices]\n", | |
"val_h = h[:, data.val_indices]\n", | |
"test_h = h[:, data.test_indices]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "0cb53e10", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(::Flux.var\"#on_trigger#134\"{Flux.var\"#on_trigger#133#135\"{Flux.var\"#137#139\"{Flux.var\"#137#138#140\"{typeof(-), Int64, var\"#5#6\"}}, Int64}}) (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model = Chain(\n", | |
" Dense(1433, 256),\n", | |
" Dense(256, 7),\n", | |
" softmax,\n", | |
")\n", | |
"\n", | |
"es = let f = () -> loss(val_h, val_y)\n", | |
" Flux.early_stopping(f, 3; init_score=f())\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "2b479432", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"acc(train_h, train_y) = 1.0\n", | |
"loss(train_h, train_y) = 0.112696424f0\n", | |
"acc(test_h, test_y) = 0.814\n", | |
"loss(test_h, test_y) = 0.5885607f0\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"┌ Info: Epoch 1\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 2\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 3\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 4\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 5\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 6\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 7\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 8\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 9\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 10\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 11\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 12\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 13\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 14\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 15\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 16\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 17\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 18\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 19\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 20\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 21\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 22\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 23\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 24\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 25\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 26\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 27\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 28\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 29\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 30\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 31\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 32\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 33\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 34\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 35\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 36\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 37\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 38\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 39\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 40\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 41\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 42\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 43\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 44\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 45\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 46\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 47\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 48\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 49\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 50\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 51\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"0.5885607f0" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dataset = Flux.DataLoader((train_h, train_y), batchsize=20, shuffle=true)\n", | |
"\n", | |
"Flux.@epochs 10000 begin\n", | |
" Flux.train!(loss, Flux.params(model), dataset, opt)\n", | |
" es() && break\n", | |
"end\n", | |
"\n", | |
"@show acc(train_h, train_y)\n", | |
"@show loss(train_h, train_y)\n", | |
"@show acc(test_h, test_y)\n", | |
"@show loss(test_h, test_y)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "55697143", | |
"metadata": {}, | |
"source": [ | |
"# GraphSAGE w/ Mini-Batch\n", | |
"\n", | |
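"Implement a trainable `GraphSAGE` layer with mini-batch training: for each batch of node indices, `GraphSampler` samples `n` neighbors per node for `K` hops, and each `MeanAggregator` concatenates a node's features with the mean of its sampled neighbors' features before applying a `Dense` layer." | |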
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "c045b751", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"3-element Vector{Matrix{Float32}}:\n", | |
" [0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0]\n", | |
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 1.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n", | |
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"struct GraphSampler\n", | |
" K::Integer\n", | |
" n::Integer\n", | |
" \n", | |
" GraphSampler(; K::Integer, n::Integer) = new(K, n)\n", | |
"end\n", | |
"\n", | |
"function (g::GraphSampler)(x::AbstractArray)\n", | |
" Zygote.ignore() do\n", | |
" ids = x\n", | |
" layers = [data.node_features[:, ids]]\n", | |
" for k in 1:g.K\n", | |
" ids = vcat([sample(\n", | |
" data.adjacency_list[u], g.n,\n", | |
" replace=length(data.adjacency_list[u]) < g.n\n", | |
" ) for u in ids]...)\n", | |
" push!(layers, data.node_features[:, ids])\n", | |
" end\n", | |
" \n", | |
" layers\n", | |
" end\n", | |
"end\n", | |
"\n", | |
"layers = GraphSampler(K=2, n=3)(1:2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "c7bc63b3", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"2-element Vector{Matrix{Float32}}:\n", | |
" [0.0 0.18969725; 0.13856356 0.088813454; … ; 0.0 0.0; 0.09959424 0.035386723]\n", | |
" [0.0 0.0 … 0.03308655 0.0; 0.040759385 0.0 … 0.0 0.058238275; … ; 0.0 0.041052293 … 0.010543101 0.030326217; 0.10880104 0.22266395 … 0.2461675 0.16641618]" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"struct MeanAggregator\n", | |
" w::Dense\n", | |
" \n", | |
" MeanAggregator((in, out)::Pair{<:Integer, <:Integer}, σ::F) where {F} = new(Dense(in * 2 => out, σ))\n", | |
"end\n", | |
"\n", | |
"function (m::MeanAggregator)(layers::Vector{Matrix{Float32}})\n", | |
" [\n", | |
" cat(\n", | |
" layers[i],\n", | |
" layers[i + 1] |>\n", | |
" x -> reshape(x, (size(layers[i], 1), :, size(layers[i], 2))) |>\n", | |
" x -> mean(x, dims=2)[:, 1, :],\n", | |
" dims=1\n", | |
" ) |> m.w\n", | |
" for i in 1:length(layers) - 1\n", | |
" ]\n", | |
"end\n", | |
"\n", | |
"Flux.@functor MeanAggregator\n", | |
"\n", | |
"MeanAggregator(1433 => 512, relu)(layers)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "140f8839", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(::Flux.var\"#on_trigger#134\"{Flux.var\"#on_trigger#133#135\"{Flux.var\"#137#139\"{Flux.var\"#137#138#140\"{typeof(-), Int64, var\"#23#24\"}}, Int64}}) (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model = Chain(\n", | |
" GraphSampler(K=2, n=8),\n", | |
" MeanAggregator(1433 => 512, relu),\n", | |
" MeanAggregator(512 => 256, relu),\n", | |
" x -> x[1],\n", | |
" Dense(256, 7),\n", | |
" softmax\n", | |
")\n", | |
"\n", | |
"es = let f = () -> loss(data.val_indices, val_y)\n", | |
" Flux.early_stopping(f, 3; init_score=f())\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "58804b66", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"acc(data.train_indices, train_y) = 1.0\n", | |
"loss(data.train_indices, train_y) = 0.013695437f0\n", | |
"acc(data.test_indices, test_y) = 0.782\n", | |
"loss(data.test_indices, test_y) = 0.70231277f0\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"┌ Info: Epoch 1\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 2\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 3\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 4\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 5\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 6\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 7\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 8\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 9\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 10\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 11\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 12\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 13\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 14\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 15\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 16\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 17\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 18\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 19\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 20\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 21\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 22\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 23\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 24\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 25\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 26\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 27\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 28\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 29\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 30\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 31\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 32\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n", | |
"┌ Info: Epoch 33\n", | |
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"0.70231277f0" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dataset = Flux.DataLoader((data.train_indices, train_y), batchsize=20, shuffle=true)\n", | |
"\n", | |
"Flux.@epochs 10000 begin\n", | |
" Flux.train!(loss, Flux.params(model), dataset, opt)\n", | |
" es() && break\n", | |
"end\n", | |
"\n", | |
"@show acc(data.train_indices, train_y)\n", | |
"@show loss(data.train_indices, train_y)\n", | |
"@show acc(data.test_indices, test_y)\n", | |
"@show loss(data.test_indices, test_y)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Julia 1.7.2", | |
"language": "julia", | |
"name": "julia-1.7" | |
}, | |
"language_info": { | |
"file_extension": ".jl", | |
"mimetype": "application/julia", | |
"name": "julia", | |
"version": "1.7.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment