Skip to content

Instantly share code, notes, and snippets.

@eahenle
Last active July 8, 2021 05:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save eahenle/928db9620b652a4d4684633418b22c6b to your computer and use it in GitHub Desktop.
Save eahenle/928db9620b652a4d4684633418b22c6b to your computer and use it in GitHub Desktop.
Trying to train a variable GNN w/ GeometricFlux.jl
first second
N 2
H 4
O 1
C 3
using CSV, DataFrames, Flux, GeometricFlux, JLD2, MetaGraphs
#####
# DEFINITIONS
#####
"""
    mg2fg(mg::MetaGraph, atom_to_int::Dict)

Convert the `MetaGraph` `mg` into a GeometricFlux `FeaturedGraph`.

Features follow the one-column-per-example convention:
- node features are a `(length(atom_to_int) + 1) × nv(mg)` matrix;
- edge features are a `1 × ne(mg)` matrix.

Both matrices are currently all zeros; filling them in is still TODO.
"""
function mg2fg(mg::MetaGraph, atom_to_int::Dict)
    adj_mat = adjacency_matrix(mg)
    nb_v = nv(mg)
    nb_e = ne(mg)
    # one row per atom code plus one extra row (presumably a catch-all slot — TODO
    # confirm); one column per vertex
    node_features = zeros(length(atom_to_int) + 1, nb_v)
    # TODO record node features
    # one column per edge, matching the column-per-example layout of node_features
    # (a nb_e × 1 matrix would disagree with the edge count whenever nb_e != 1)
    edge_features = zeros(1, nb_e)
    # TODO record edge features
    return FeaturedGraph(adj_mat; directed=:undirected, N=nb_v, E=nb_e,
                         nf=node_features, ef=edge_features)
end
"""
    data_loader(train_inputs, test_inputs, atom_to_int; path="graph_dict.jld2")

Load the processed inputs from the JLD2 file at `path` and split them into training
and testing sets.

The file must contain a variable `obj` of the form

    Dict(input_name::String => Dict(:graph => g::MetaGraph, :target => y::Float64))

`train_inputs` and `test_inputs` are collections of keys into `obj`; an unknown key
raises a `KeyError`. Each graph is converted with `mg2fg(g, atom_to_int)`.

Returns `(training_data, test_data)`, each a lazy `zip` of `(FeaturedGraph, target)`
pairs.
"""
function data_loader(train_inputs, test_inputs, atom_to_int; path="graph_dict.jld2")
    # load the data Dict (binds the local variable `obj`)
    @load path obj
    # turn one subset of input names into zipped (graph, target) pairs
    function assemble(names)
        entries = [obj[name] for name ∈ names]
        graphs = [mg2fg(entry[:graph], atom_to_int) for entry ∈ entries]
        targets = [entry[:target] for entry ∈ entries]
        return zip(graphs, targets)
    end
    return assemble(train_inputs), assemble(test_inputs)
end
"""
    test_train_split(input_list, test_fraction)

Split `input_list` into `(train, test)`, keeping its order and reserving the
trailing entries for testing.

The training set holds the first `floor(n * (1 - test_fraction))` entries, where
`n = length(input_list)`; every remaining entry goes to the test set.

Throws `ArgumentError` unless `0 ≤ test_fraction ≤ 1`.
"""
function test_train_split(input_list, test_fraction)
    0 ≤ test_fraction ≤ 1 ||
        throw(ArgumentError("test_fraction must lie in [0, 1], got $test_fraction"))
    n = length(input_list)
    nb_train = floor(Int, n * (1 - test_fraction))
    return input_list[1:nb_train], input_list[(nb_train + 1):n]
end
"""
    build_model(hidden_size, time_steps)

Build a graph-level regression network. The pipeline is:

1. a `GatedGraphConv` layer with `hidden_size` channels and `time_steps`
   message-passing steps;
2. sum-pooling over nodes;
3. a `Dense` readout that yields one value per graph.
"""
function build_model(hidden_size, time_steps)
    return Chain(
        GatedGraphConv(hidden_size, time_steps, aggr=:mean),
        # the conv layer returns a FeaturedGraph; extract its node-feature matrix
        # (one column per node)
        node_feature,
        # pool over nodes (columns) into a single graph embedding column
        x -> sum(x; dims=2),
        # linear scalar readout: the target is an unbounded Float64 scored with MSE,
        # so a softmax (outputs forced to sum to 1) would be the wrong output layer
        Dense(hidden_size, 1),
    )
end
"""
    train_model!(model, training_data, nb_epochs, opt)

Train `model` in place for `nb_epochs` passes over `training_data`, minimizing
mean-squared error with optimiser `opt`.

`training_data` is an iterable of `(input, target)` pairs. Logs `Epoch i` at the
start of each epoch and returns `nothing`.
"""
function train_model!(model, training_data, nb_epochs, opt)
    loss(x, y) = Flux.mse(model(x), y)
    # the parameter set of a fixed model does not change between epochs
    ps = Flux.params(model)
    # explicit loop instead of `Flux.@epochs`, which Flux 0.13 deprecated
    for epoch ∈ 1:nb_epochs
        @info "Epoch $epoch"
        Flux.train!(loss, ps, training_data, opt)
    end
    return nothing
end
#####
# PARAMETERS
#####
# optimisation
opt = ADAM()            # gradient-based optimiser
nb_epochs = 10          # full passes over the training set

# architecture
t = 5                   # message-passing steps in GatedGraphConv
sz_h = 100              # width of the GRU hidden state

# data
test_fraction = 0       # share of inputs held out for testing

# message-passing hooks: forward neighbour features / messages unchanged
# NOTE(review): these define new functions in this script's module rather than
# extending GeometricFlux's own `message`/`update`, so GatedGraphConv presumably
# never calls them — confirm before relying on them
function message(mp, x_i, x_j, e_ij)
    return x_j
end
function update(mp, m, x)
    return m
end
#####
# DATA LOADING
#####
# names of the input structures (a single placeholder for now)
inputs = ["dummy_input"]

# build and persist a placeholder dataset: one empty MetaGraph with target 1.0
g = MetaGraph()
y = 1.0
obj = Dict()
obj["dummy_input"] = Dict(:graph => g, :target => y)
@save "graph_dict.jld2" obj

# partition the input names into training and testing subsets
train_inputs, test_inputs = test_train_split(inputs, test_fraction)

# atom encoding: a two-column CSV whose header is replaced by (atom, int)
df = CSV.read("atom_to_int.csv", DataFrame)
rename!(df, [:atom, :int])
atom_to_int = Dict(zip(df.atom, df.int))

# assemble (graph, target) data for both subsets
training_data, test_data = data_loader(train_inputs, test_inputs, atom_to_int)
#####
# MODEL TRAINING
#####
# build the GNN from the hyperparameters above, then fit it in place on the
# training pairs
model = build_model(sz_h, t)
train_model!(model, training_data, nb_epochs, opt)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment