A GraphSAGE implementation with Julia & Flux.jl
@queensferryme · Created April 23, 2022
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "1e2afde0",
"metadata": {},
"outputs": [],
"source": [
"using Flux\n",
"using LinearAlgebra: norm\n",
"using MLDatasets: Cora\n",
"using StatsBase: mean, sample\n",
"\n",
"import Zygote"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "21a1dfcf",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(node_features = Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], node_labels = [4, 5, 5, 1, 4, 3, 1, 4, 4, 3 … 4, 4, 4, 4, 1, 4, 4, 4, 4, 4], adjacency_list = [[634, 1863, 2583], [3, 653, 655], [1987, 333, 1667, 2, 1455], [2545], [2177, 1017, 1762, 1257, 2176], [1630, 2547, 1660], [1417, 1603, 1043, 374], [209], [282, 1997, 270], [2615, 724] … [1401, 1574], [2631], [1152], [45, 2625], [187, 1537], [1299], [642], [288], [166, 2708, 1474, 170], [599, 166, 1474, 2707]], train_indices = 1:140, val_indices = 141:640, test_indices = 1709:2708, num_classes = 7, num_nodes = 2708, num_edges = 10556, directed = false)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = Cora.dataset()"
]
},
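{
"cell_type": "markdown",
"id": "a0c1d2e3",
"metadata": {},
"source": [
"Cora is a citation graph: 2708 nodes (papers) with 1433-dimensional bag-of-words features stored column-wise (one node per column), 10556 edge entries (`directed = false`) and 7 classes. The split used below is 140 training nodes (`1:140`), 500 validation nodes (`141:640`) and 1000 test nodes (`1709:2708`)."
]
},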
{
"cell_type": "markdown",
"id": "842a981c",
"metadata": {},
"source": [
"# Linear Regression\n",
"\n",
"Just ignore the graph topology :("
]
},
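{
"cell_type": "markdown",
"id": "b1d2e3f4",
"metadata": {},
"source": [
"The model below stacks two `Dense` layers with no nonlinearity in between, so it is effectively a single affine map followed by `softmax` (multinomial logistic regression), trained with mean cross-entropy:\n",
"\n",
"$$\\hat{y} = \\mathrm{softmax}\\left(W_2 (W_1 x + b_1) + b_2\\right), \\qquad \\mathcal{L} = -\\frac{1}{N} \\sum_{i=1}^{N} \\sum_{c=1}^{7} y_{ic} \\log \\hat{y}_{ic}$$"
]
},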
{
"cell_type": "code",
"execution_count": 3,
"id": "02c6dbbb",
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"7×1000 OneHotMatrix(::Vector{UInt32}) with eltype Bool:\n",
" ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ … ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅\n",
" ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅\n",
" ⋅ 1 1 1 1 ⋅ 1 1 1 1 1 1 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅\n",
" 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 1 1 1 1 1 ⋅ 1 1 1 1 1\n",
" ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅\n",
" ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ … ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅\n",
" ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = data.node_features\n",
"y = Flux.onehotbatch(data.node_labels, 1:7)\n",
"\n",
"train_x = x[:, data.train_indices]\n",
"train_y = y[:, data.train_indices]\n",
"\n",
"val_x = x[:, data.val_indices]\n",
"val_y = y[:, data.val_indices]\n",
"\n",
"test_x = x[:, data.test_indices]\n",
"test_y = y[:, data.test_indices]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4f8d91c9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Descent(0.1)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"acc(x, y) = mean(Flux.onecold(model(x)) .== Flux.onecold(y))\n",
"loss(x, y) = Flux.crossentropy(model(x), y)\n",
"opt = Descent()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "04e64719",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(::Flux.var\"#on_trigger#134\"{Flux.var\"#on_trigger#133#135\"{Flux.var\"#137#139\"{Flux.var\"#137#138#140\"{typeof(-), Int64, var\"#1#2\"}}, Int64}}) (generic function with 1 method)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = Chain(\n",
" Dense(1433, 256),\n",
" Dense(256, 7),\n",
" softmax,\n",
")\n",
"\n",
"es = let f = () -> loss(val_x, val_y)\n",
" Flux.early_stopping(f, 3; init_score=f())\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "125b5a14",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"acc(train_x, train_y) = 1.0\n",
"loss(train_x, train_y) = 0.069662094f0\n",
"acc(test_x, test_y) = 0.573\n",
"loss(test_x, test_y) = 1.2368718f0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Info: Epoch 1\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 2\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 3\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 4\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 5\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 6\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 7\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 8\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 9\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 10\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 11\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 12\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 13\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 14\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 15\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 16\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 17\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 18\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 19\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 20\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 21\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 22\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 23\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 24\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 25\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 26\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 27\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 28\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 29\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 30\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 31\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 32\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n"
]
},
{
"data": {
"text/plain": [
"1.2368718f0"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset = Flux.DataLoader((train_x, train_y), batchsize=20, shuffle=true)\n",
"\n",
"Flux.@epochs 10000 begin\n",
" Flux.train!(loss, Flux.params(model), dataset, opt)\n",
" es() && break\n",
"end\n",
"\n",
"@show acc(train_x, train_y)\n",
"@show loss(train_x, train_y)\n",
"@show acc(test_x, test_y)\n",
"@show loss(test_x, test_y)"
]
},
{
"cell_type": "markdown",
"id": "114bd83f",
"metadata": {},
"source": [
"# GraphSAGE\n",
"\n",
"We first learn the node features with GraphSAGE, then apply linear regression."
]
},
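{
"cell_type": "markdown",
"id": "c2e3f4a5",
"metadata": {},
"source": [
"The propagation in the next cell has no trainable weights: for $K = 2$ rounds, every node's representation is replaced by the ReLU of the mean over itself and its neighbours,\n",
"\n",
"$$h_i^{(k)} = \\mathrm{ReLU}\\left(\\frac{1}{|\\mathcal{N}(i)| + 1} \\sum_{j \\in \\mathcal{N}(i) \\cup \\{i\\}} h_j^{(k-1)}\\right), \\qquad h_i^{(0)} = x_i$$"
]
},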
{
"cell_type": "code",
"execution_count": 7,
"id": "7bcc168b",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"K = 2\n",
"h = copy(x)\n",
"\n",
"for _ in 1:K\n",
" h_next = similar(h)\n",
" for i in 1:data.num_nodes\n",
" h_next[:, i] = begin\n",
" h[:, cat(data.adjacency_list[i], [i], dims=1)] |>\n",
" x -> mean(x, dims=2) .|>\n",
" relu\n",
" end\n",
" end\n",
" h = h_next\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "28c80808",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"1433×1000 Matrix{Float32}:\n",
" 0.000845309 0.000986193 0.000986193 … 0.0 0.0 0.0 0.0 0.00588235\n",
" 0.00169062 0.00449764 0.00197239 0.0 0.0 0.0 0.0 0.00588235\n",
" 0.0705631 0.0105343 0.00295858 0.0 0.5 0.0 0.0 0.00588235\n",
" 0.00338123 0.00394477 0.00394477 0.0 0.0 0.0 0.0 0.0\n",
" 0.000845309 0.000986193 0.000986193 0.0 0.0 0.0 0.0 0.00588235\n",
" 0.0 0.0 0.0 … 0.5 0.0 0.0 0.0 0.00588235\n",
" 0.00591716 0.0260953 0.0381534 0.0 0.0 0.0 0.0 0.0\n",
" 0.208328 0.000986193 0.000986193 0.0 0.0 0.0 0.0 0.00588235\n",
" 0.00253593 0.00295858 0.00295858 0.0 0.0 0.0 0.0 0.0\n",
" 0.000845309 0.000986193 0.000986193 0.0 0.0 0.0 0.0 0.0\n",
" 0.0 0.0166667 0.0 … 0.0 0.0 0.0 0.0 0.0\n",
" 0.0 0.0 0.0333333 0.0 0.0 0.0 0.0 0.0\n",
" 0.000845309 0.000986193 0.0114029 0.0 0.0 0.0 0.0 0.0\n",
" ⋮ ⋱ ⋮ \n",
" 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n",
" 0.00507185 0.00844241 0.00591716 0.0 0.0 0.0 0.0 0.0\n",
" 0.0918166 0.00647003 0.0610876 0.0 0.0 0.0 0.0 0.00588235\n",
" 0.00422654 0.00745622 0.00493097 0.0 0.5 0.0 0.0 0.0\n",
" 0.0 0.00252525 0.0238095 … 0.5 0.0 0.0 0.28 0.291765\n",
" 0.0968885 0.0452155 0.0336715 0.0 0.0 0.0 0.0 0.0\n",
" 0.0229441 0.00295858 0.00295858 0.0 0.0 0.0 0.0 0.0\n",
" 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n",
" 0.00169062 0.102604 0.00197239 0.0 0.0 0.0 0.0 0.0\n",
" 0.0 0.0 0.0 … 0.0 0.0 0.0 0.0 0.0\n",
" 0.0688725 0.000986193 0.000986193 0.0 0.0 0.0 0.0 0.0\n",
" 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.00588235"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_h = h[:, data.train_indices]\n",
"val_h = h[:, data.val_indices]\n",
"test_h = h[:, data.test_indices]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "0cb53e10",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(::Flux.var\"#on_trigger#134\"{Flux.var\"#on_trigger#133#135\"{Flux.var\"#137#139\"{Flux.var\"#137#138#140\"{typeof(-), Int64, var\"#5#6\"}}, Int64}}) (generic function with 1 method)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = Chain(\n",
" Dense(1433, 256),\n",
" Dense(256, 7),\n",
" softmax,\n",
")\n",
"\n",
"es = let f = () -> loss(val_h, val_y)\n",
" Flux.early_stopping(f, 3; init_score=f())\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "2b479432",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"acc(train_h, train_y) = 1.0\n",
"loss(train_h, train_y) = 0.112696424f0\n",
"acc(test_h, test_y) = 0.814\n",
"loss(test_h, test_y) = 0.5885607f0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Info: Epoch 1\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 2\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 3\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 4\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 5\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 6\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 7\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 8\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 9\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 10\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 11\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 12\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 13\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 14\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 15\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 16\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 17\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 18\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 19\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 20\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 21\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 22\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 23\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 24\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 25\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 26\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 27\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 28\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 29\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 30\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 31\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 32\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 33\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 34\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 35\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 36\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 37\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 38\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 39\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 40\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 41\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 42\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 43\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 44\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 45\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 46\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 47\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 48\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 49\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 50\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 51\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n"
]
},
{
"data": {
"text/plain": [
"0.5885607f0"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset = Flux.DataLoader((train_h, train_y), batchsize=20, shuffle=true)\n",
"\n",
"Flux.@epochs 10000 begin\n",
" Flux.train!(loss, Flux.params(model), dataset, opt)\n",
" es() && break\n",
"end\n",
"\n",
"@show acc(train_h, train_y)\n",
"@show loss(train_h, train_y)\n",
"@show acc(test_h, test_y)\n",
"@show loss(test_h, test_y)"
]
},
{
"cell_type": "markdown",
"id": "55697143",
"metadata": {},
"source": [
"# GraphSAGE w/ Mini-Batch\n",
"\n",
"Implement a trainable `GraphSAGE` layer with mini-batch."
]
},
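{
"cell_type": "markdown",
"id": "d3f4a5b6",
"metadata": {},
"source": [
"Each trainable layer follows the GraphSAGE mean aggregator with neighbourhood sampling: draw a fixed number $n$ of neighbours per node (with replacement when a node has fewer than $n$), average their features, concatenate with the node's own features and apply a dense layer,\n",
"\n",
"$$h_i^{(k)} = \\sigma\\left(W^{(k)} \\left[\\, h_i^{(k-1)} \\;\\Vert\\; \\operatorname{mean}_{j \\in \\mathcal{S}_n(i)} h_j^{(k-1)} \\right]\\right)$$\n",
"\n",
"where $\\mathcal{S}_n(i)$ is the sampled neighbour set. `GraphSampler` draws the samples and `MeanAggregator` implements the update."
]
},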
{
"cell_type": "code",
"execution_count": 11,
"id": "c045b751",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3-element Vector{Matrix{Float32}}:\n",
" [0.0 0.0; 0.0 0.0; … ; 0.0 0.0; 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 1.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"struct GraphSampler\n",
" K::Integer\n",
" n::Integer\n",
" \n",
" GraphSampler(; K::Integer, n::Integer) = new(K, n)\n",
"end\n",
"\n",
"function (g::GraphSampler)(x::AbstractArray)\n",
" Zygote.ignore() do\n",
" ids = x\n",
" layers = [data.node_features[:, ids]]\n",
" for k in 1:g.K\n",
" ids = vcat([sample(\n",
" data.adjacency_list[u], g.n,\n",
" replace=length(data.adjacency_list[u]) < g.n\n",
" ) for u in ids]...)\n",
" push!(layers, data.node_features[:, ids])\n",
" end\n",
" \n",
" layers\n",
" end\n",
"end\n",
"\n",
"layers = GraphSampler(K=2, n=3)(1:2)"
]
},
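{
"cell_type": "markdown",
"id": "e4a5b6c7",
"metadata": {},
"source": [
"`GraphSampler(K=2, n=3)(1:2)` returns $K + 1$ feature matrices: `layers[1]` holds the features of the query nodes themselves, and `layers[k+1]` holds the features of the `n` sampled neighbours of every node in `layers[k]`, so it has $n^k$ · batch columns (here 2, 6 and 18). Sampling with replacement is used only when a node has fewer than `n` neighbours, and everything runs inside `Zygote.ignore`, so the sampling never enters the gradient computation."
]
},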
{
"cell_type": "code",
"execution_count": 12,
"id": "c7bc63b3",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"2-element Vector{Matrix{Float32}}:\n",
" [0.0 0.18969725; 0.13856356 0.088813454; … ; 0.0 0.0; 0.09959424 0.035386723]\n",
" [0.0 0.0 … 0.03308655 0.0; 0.040759385 0.0 … 0.0 0.058238275; … ; 0.0 0.041052293 … 0.010543101 0.030326217; 0.10880104 0.22266395 … 0.2461675 0.16641618]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"struct MeanAggregator\n",
" w::Dense\n",
" \n",
" MeanAggregator((in, out)::Pair{<:Integer, <:Integer}, σ::F) where {F} = new(Dense(in * 2 => out, σ))\n",
"end\n",
"\n",
"function (m::MeanAggregator)(layers::Vector{Matrix{Float32}})\n",
" [\n",
" cat(\n",
" layers[i],\n",
" layers[i + 1] |>\n",
" x -> reshape(x, (size(layers[i], 1), :, size(layers[i], 2))) |>\n",
" x -> mean(x, dims=2)[:, 1, :],\n",
" dims=1\n",
" ) |> m.w\n",
" for i in 1:length(layers) - 1\n",
" ]\n",
"end\n",
"\n",
"Flux.@functor MeanAggregator\n",
"\n",
"MeanAggregator(1433 => 512, relu)(layers)"
]
},
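{
"cell_type": "markdown",
"id": "f5b6c7d8",
"metadata": {},
"source": [
"Because the sampler emits each node's `n` neighbours as a contiguous block of columns, `reshape(x, (d, :, batch))` groups them into a $d \\times n \\times$ batch array; averaging over the middle dimension gives one neighbourhood mean per node, which is concatenated with the node's own features and fed through `Dense(in * 2 => out, σ)`. Applying the aggregator to consecutive pairs of layers shrinks the list by one, so after $K$ aggregators only the representations of the original query nodes remain (extracted by `x -> x[1]` in the model below). `Flux.@functor` registers the inner `Dense` so its weights are trainable."
]
},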
{
"cell_type": "code",
"execution_count": 13,
"id": "140f8839",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"(::Flux.var\"#on_trigger#134\"{Flux.var\"#on_trigger#133#135\"{Flux.var\"#137#139\"{Flux.var\"#137#138#140\"{typeof(-), Int64, var\"#23#24\"}}, Int64}}) (generic function with 1 method)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = Chain(\n",
" GraphSampler(K=2, n=8),\n",
" MeanAggregator(1433 => 512, relu),\n",
" MeanAggregator(512 => 256, relu),\n",
" x -> x[1],\n",
" Dense(256, 7),\n",
" softmax\n",
")\n",
"\n",
"es = let f = () -> loss(data.val_indices, val_y)\n",
" Flux.early_stopping(f, 3; init_score=f())\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "58804b66",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"acc(data.train_indices, train_y) = 1.0\n",
"loss(data.train_indices, train_y) = 0.013695437f0\n",
"acc(data.test_indices, test_y) = 0.782\n",
"loss(data.test_indices, test_y) = 0.70231277f0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Info: Epoch 1\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 2\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 3\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 4\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 5\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 6\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 7\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 8\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 9\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 10\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 11\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 12\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 13\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 14\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 15\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 16\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 17\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 18\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 19\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 20\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 21\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 22\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 23\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 24\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 25\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 26\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 27\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 28\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 29\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 30\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 31\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 32\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n",
"┌ Info: Epoch 33\n",
"└ @ Main /Users/queensferry/.julia/packages/Flux/18YZE/src/optimise/train.jl:153\n"
]
},
{
"data": {
"text/plain": [
"0.70231277f0"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset = Flux.DataLoader((data.train_indices, train_y), batchsize=20, shuffle=true)\n",
"\n",
"Flux.@epochs 10000 begin\n",
" Flux.train!(loss, Flux.params(model), dataset, opt)\n",
" es() && break\n",
"end\n",
"\n",
"@show acc(data.train_indices, train_y)\n",
"@show loss(data.train_indices, train_y)\n",
"@show acc(data.test_indices, test_y)\n",
"@show loss(data.test_indices, test_y)"
]
}
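,
{
"cell_type": "markdown",
"id": "a6c7d8e9",
"metadata": {},
"source": [
"A minimal usage sketch: node indices can be fed straight into the trained model, and because `GraphSampler` draws fresh neighbour samples on every forward pass the predictions are slightly stochastic across runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7d8e9f0",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: predict classes for the first few test nodes and compare with the labels\n",
"# (results vary across runs because GraphSampler resamples neighbourhoods each call).\n",
"predicted = Flux.onecold(model(data.test_indices[1:5]))\n",
"actual = data.node_labels[data.test_indices[1:5]]\n",
"predicted, actual"
]
}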
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.7.2",
"language": "julia",
"name": "julia-1.7"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.7.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}