-
-
Save eahenle/928db9620b652a4d4684633418b22c6b to your computer and use it in GitHub Desktop.
Trying to train a variable GNN w/ GeometricFlux.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| first | second |
|---|---|
| N | 2 |
| H | 4 |
| O | 1 |
| C | 3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using CSV, DataFrames, Flux, GeometricFlux, JLD2, MetaGraphs | |
##### | |
# DEFINITIONS | |
##### | |
# Converts a MetaGraph into a FeaturedGraph.
#
# Arguments:
#   mg          - the graph to convert
#   atom_to_int - Dict whose length fixes the node-feature dimension
#                 (presumably atom label => integer code; confirm against caller)
#
# Returns a FeaturedGraph built from the adjacency matrix of `mg`. Both feature
# matrices are currently all zeros: encoding the real features is still a TODO.
function mg2fg(mg::MetaGraph, atom_to_int::Dict)
    adj_mat = adjacency_matrix(mg)
    nb_v = nv(mg)
    nb_e = ne(mg)

    # examples are column vectors: one column per node
    # NOTE(review): the purpose of the extra (+1) feature row is not documented — confirm
    node_features = zeros(length(atom_to_int) + 1, nb_v)
    for i ∈ 1:nb_v
        # TODO record node features
    end

    # one column per edge, so edges use the same (feature, item) layout as nodes;
    # the item count must sit on the last axis to agree with E=nb_e
    edge_features = zeros(1, nb_e)
    for i ∈ 1:nb_e
        # TODO record edge features
    end

    return FeaturedGraph(
        adj_mat;
        directed=:undirected,
        N=nb_v,
        E=nb_e,
        nf=node_features,
        ef=edge_features,
    )
end
# Builds the training and test data sets from the graph dictionary stored on disk.
#
# The file "graph_dict.jld2" must hold a variable `obj` of the form
#   Dict(
#       input_name::String => Dict(
#           :graph => g::MetaGraph,
#           :target => y::Float64
#       )
#   )
# `train_inputs` and `test_inputs` list the keys of `obj` that belong to each set;
# `atom_to_int` is forwarded unchanged to `mg2fg`.
#
# Returns (training_data, test_data), each a `zip` of the converted graphs with
# their targets.
function data_loader(train_inputs, test_inputs, atom_to_int)
    # read the data Dict from disk
    @load "graph_dict.jld2" obj

    # look up the named entries, convert their graphs, and pair each with its target
    function paired(names)
        entries = [obj[name] for name ∈ names]
        graphs = [mg2fg(entry[:graph], atom_to_int) for entry ∈ entries]
        targets = [entry[:target] for entry ∈ entries]
        return zip(graphs, targets)
    end

    return paired(train_inputs), paired(test_inputs)
end
# Splits off the tail of the input list to give a test set; the head is the
# training set.
#
# Arguments:
#   input_list    - indexable collection of inputs, kept in its original order
#   test_fraction - fraction of the inputs (between 0 and 1) reserved for testing
#
# Returns (train, test). The training part holds the first
# floor(n * (1 - test_fraction)) elements and the test part holds the rest.
# Either part may be empty (e.g. test_fraction = 0 reserves nothing for testing).
#
# Throws an ArgumentError when test_fraction lies outside [0, 1]; such a value
# would otherwise produce non-positive indices and an opaque BoundsError.
function test_train_split(input_list, test_fraction)
    0 <= test_fraction <= 1 || throw(
        ArgumentError("test_fraction must lie in [0, 1], got $test_fraction")
    )
    n = length(input_list)
    # floor straight to an Int so the index ranges below are integer ranges
    last_train_id = floor(Int, n * (1 - test_fraction))
    return input_list[1:last_train_id], input_list[(last_train_id + 1):n]
end
# Assembles the untrained network: a GatedGraphConv layer (hidden size
# `hidden_size`, `time_steps` steps, mean aggregation) followed by softmax.
function build_model(hidden_size, time_steps)
    gated_layer = GatedGraphConv(hidden_size, time_steps; aggr=:mean)
    return Chain(gated_layer, softmax)
end
# Trains `model` in place for `nb_epochs` passes over `training_data`, using the
# optimizer `opt` and the mean squared error between model output and target.
# Each element of `training_data` is an (input, target) pair.
function train_model!(model, training_data, nb_epochs, opt)
    # objective evaluated on a single (input, target) pair
    objective = (input, target) -> Flux.mse(model(input), target)
    Flux.@epochs nb_epochs Flux.train!(objective, Flux.params(model), training_data, opt)
end
#####
# PARAMETERS
#####
# training hyper-parameters
opt = ADAM()       # gradient optimizer
nb_epochs = 10     # number of training epochs
t = 5              # number of message-passing steps
sz_h = 100         # size of hidden state in GRU
test_fraction = 0  # fraction of inputs to reserve for testing

# message-passing functions
# NOTE(review): neither function is referenced elsewhere in the visible script —
# confirm they are picked up by the message-passing layer as intended

# the message is the x_j argument, passed through unchanged
function message(mp, x_i, x_j, e_ij)
    return x_j
end

# the update is the message m, passed through unchanged
function update(mp, m, x)
    return m
end
#####
# DATA LOADING
#####
# names of the input structures (a single placeholder entry)
inputs = ["dummy_input"]

# write a placeholder graph dictionary to disk for data_loader to read back
g = MetaGraph()
y = 1.0
obj = Dict{Any,Any}("dummy_input" => Dict(:graph => g, :target => y))
@save "graph_dict.jld2" obj

# designate test inputs
train_inputs, test_inputs = test_train_split(inputs, test_fraction)

# read the atom encoding scheme: the two CSV columns become :atom and :int
df = rename!(CSV.read("atom_to_int.csv", DataFrame), [:atom, :int])
atom_to_int = Dict([row.atom => row.int for row ∈ eachrow(df)])

# load data
training_data, test_data = data_loader(train_inputs, test_inputs, atom_to_int)

#####
# MODEL TRAINING
#####
# instantiate the model, then train it on the training set
model = build_model(sz_h, t)
train_model!(model, training_data, nb_epochs, opt)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment