This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Imports used by this training step, grouped ahead of the statements
# that need them.
import attr
from torchbiggraph.config import parse_config
from torchbiggraph.train import train

# NOTE(review): `os`, `convert_path`, `CONFIG_PATH`, `DATA_DIR` and
# `FILENAMES` are not defined in this fragment; they are expected to be
# in scope already when it runs.

# Load the configuration found at CONFIG_PATH.
train_config = parse_config(CONFIG_PATH)

# Single-element list holding the converted (partitioned) path derived
# from the 'train' entry of FILENAMES under DATA_DIR.
train_path = [convert_path(os.path.join(DATA_DIR, FILENAMES['train']))]

# Replace only `edge_paths` on the parsed config, then hand it to train().
train_config = attr.evolve(train_config, edge_paths=train_path)
train(train_config)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE(review): `os`, `DATA_DIR` and `FILENAMES` are not defined in this
# fragment; they are expected to be in scope already when it runs.

# One path per value in FILENAMES, each joined onto DATA_DIR.
edge_paths = [os.path.join(DATA_DIR, name) for name in FILENAMES.values()]

from torchbiggraph.converters.import_from_tsv import convert_input_data
convert_input_data( | |
CONFIG_PATH, | |
edge_paths, | |
lhs_col=0, | |
rhs_col=1, | |
rel_col=None, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import random | |
def convert_path(fname):
    """Return the ``_partitioned`` path derived from *fname*.

    The extension of *fname* (as split off by ``os.path.splitext``) is
    dropped and the literal suffix ``'_partitioned'`` is appended to what
    remains. A name without an extension simply gets the suffix appended.

    Args:
        fname: Path string of the input file.

    Returns:
        The derived path string; nothing is created on disk here.
    """
    basename, _ = os.path.splitext(fname)
    return basename + '_partitioned'
def random_split_file(fpath): | |
root = os.path.dirname(fpath) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE(review): `A_norm`, `X`, `Y`, `K` and the Keras/keras_dgl names used
# below are not defined in this fragment; they are expected to be in scope.

# Stack two filter matrices along axis 0: A_norm itself and the matrix
# product A_norm @ A_norm, then wrap the result as a backend constant.
graph_conv_filters = np.concatenate([A_norm, np.matmul(A_norm, A_norm)], axis=0)
graph_conv_filters = K.constant(graph_conv_filters)

# Presumably has to equal the number of matrices stacked above --
# TODO confirm against the keras_dgl GraphCNN documentation.
num_filters = 2

# One GraphCNN layer followed by a softmax activation.
model = Sequential()
model.add(
    GraphCNN(
        Y.shape[1],
        num_filters,
        graph_conv_filters,
        input_shape=(X.shape[1],),
        activation='elu',
        kernel_regularizer=l2(5e-4),
    )
)
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=0.01), metrics=['acc'])
model.summary()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE(review): `A_norm`, `X`, `Y`, `K` and the Keras/keras_dgl names used
# below are not defined in this fragment; they are expected to be in scope.
print("this filter includes the edges, so it should perform considerably better than before.:")

# Use A_norm unchanged as the only filter, wrapped as a backend constant.
graph_conv_filters = A_norm
graph_conv_filters = K.constant(graph_conv_filters)
num_filters = 1

# One GraphCNN layer followed by a softmax activation.
model = Sequential()
model.add(
    GraphCNN(
        Y.shape[1],
        num_filters,
        graph_conv_filters,
        input_shape=(X.shape[1],),
        activation='elu',
        kernel_regularizer=l2(5e-4),
    )
)
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=0.01), metrics=['acc'])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE(review): `A_norm`, `X`, `Y`, `K` and the Keras/keras_dgl names used
# below are not defined in this fragment; they are expected to be in scope.
print("this simply ignores the connected edges, thus receives a pretty bad test_acc:")

# Identity matrix with as many rows as A_norm, used as the only filter
# and wrapped as a backend constant.
graph_conv_filters = np.eye(A_norm.shape[0])
graph_conv_filters = K.constant(graph_conv_filters)
num_filters = 1

# One GraphCNN layer followed by a softmax activation.
model = Sequential()
model.add(
    GraphCNN(
        Y.shape[1],
        num_filters,
        graph_conv_filters,
        input_shape=(X.shape[1],),
        activation='elu',
        kernel_regularizer=l2(5e-4),
    )
)
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=0.01), metrics=['acc'])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import sys

# The search-path entries are built from the current working directory and
# must be appended before the `examples` / `keras_dgl` imports below.
sys.path.append(os.path.join(os.getcwd(), "keras-deep-graph-learning"))  # Adding the submodule to the module search path
sys.path.append(os.path.join(os.getcwd(), "keras-deep-graph-learning/examples"))  # Adding the submodule to the module search path

import numpy as np
from examples import utils
from keras.layers import Dense, Activation, Dropout
from keras.models import Model, Sequential
from keras.regularizers import l2
from keras.optimizers import Adam
from keras_dgl.layers import GraphCNN
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE(review): `A`, `X`, `Y`, `K` and the Keras/keras_dgl names used below
# are not defined in this fragment; they are expected to be in scope.

# Presumably a per-sample weight/mask (last entry zero) consumed by a later
# fit call that is not visible here -- TODO confirm.
train_on_weight = np.array([1, 1, 0])

print("Now we won't do any fancy preprocessing, just basic training.")

NUM_FILTERS = 1
graph_conv_filters = A  # you may try np.eye(3)
graph_conv_filters = K.constant(graph_conv_filters)

# One GraphCNN layer followed by a softmax activation.
model = Sequential()
model.add(
    GraphCNN(
        Y.shape[1],
        NUM_FILTERS,
        graph_conv_filters,
        input_shape=(X.shape[1],),
        activation='elu',
        kernel_regularizer=l2(5e-4),
    )
)
model.add(Activation('softmax'))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import sys

# The search-path entries are built from the current working directory and
# must be appended before the `keras_dgl` import below.
sys.path.append(os.path.join(os.getcwd(), "keras-deep-graph-learning"))  # Adding the submodule to the module search path
sys.path.append(os.path.join(os.getcwd(), "keras-deep-graph-learning/examples"))  # Adding the submodule to the module search path

import numpy as np
from keras.layers import Dense, Activation, Dropout
from keras.models import Model, Sequential
from keras.regularizers import l2
from keras.optimizers import Adam
from keras_dgl.layers import GraphCNN
import keras.backend as K
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import s3fs | |
import fire | |
import os | |
class S3CopyMachine(object): | |
"""Copy to S3 via s3fs.""" | |
def to_s3(self, local_bucket, s3_bucket): |