cora dataset from scratch
import os

import numpy as np
import networkx as nx
from collections import Counter
from sklearn.utils import shuffle
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import Input, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import TensorBoard, EarlyStopping
from tensorflow.keras.regularizers import l2
from spektral.layers import GCNConv
# line by line, build lists of node data and citation edges
all_data = []   # from the .content file: node ids, features, labels
all_edges = []  # from the .cites file: citation edges
for root, dirs, files in os.walk('./cora'):
    for file in files:
        if '.content' in file:
            with open(os.path.join(root, file), 'r') as f:
                all_data.extend(f.read().splitlines())
        elif 'cites' in file:
            with open(os.path.join(root, file), 'r') as f:
                all_edges.extend(f.read().splitlines())
# shuffle the data, because the raw .content lines are ordered by label
random_state = 47
all_data = shuffle(all_data, random_state=random_state)
# parse the content lines into node ids, features, and labels
labels = []
nodes = []
X = []
for i, data in enumerate(all_data):
    elements = data.split('\t')
    labels.append(elements[-1])
    X.append(elements[1:-1])
    nodes.append(elements[0])
X = np.array(X, dtype=int)
N = X.shape[0]  # the number of nodes
F = X.shape[1]  # the size of node features
print('X shape: ', X.shape)
# parse the citation edges
edge_list = []
for edge in all_edges:
    e = edge.split('\t')
    edge_list.append((e[0], e[1]))

print('\nNumber of nodes (N): ', N)
print('\nNumber of features (F) of each node: ', F)
print('\nCategories: ', set(labels))
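# For the standard Cora download, these prints should report N = 2708 nodes,
# F = 1433 binary word features, and 7 paper categories.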
def limit_data(labels, limit=20, val_num=500, test_num=1000):
    '''
    Get the indices of the train, validation, and test data
    '''
    label_counter = dict((l, 0) for l in labels)
    train_idx = []
    for i in range(len(labels)):
        label = labels[i]
        if label_counter[label] < limit:
            # add the example to the training data
            train_idx.append(i)
            label_counter[label] += 1
        # exit the loop once we have found `limit` examples for each class
        if all(count == limit for count in label_counter.values()):
            break
    # get the indices that do not go into the training data
    rest_idx = [x for x in range(len(labels)) if x not in train_idx]
    # the first val_num of the rest go to validation, the next test_num to test
    val_idx = rest_idx[:val_num]
    test_idx = rest_idx[val_num:(val_num + test_num)]
    return train_idx, val_idx, test_idx
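# Note: this mimics the sizes of the "Planetoid" split used in the GCN paper
# (Kipf & Welling, 2017): 20 labeled nodes per class for training, 500
# validation nodes, and 1000 test nodes. The exact node sets differ from the
# canonical split, since they are drawn here from the shuffled data.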
train_idx, val_idx, test_idx = limit_data(labels)

# set boolean masks over the N nodes for each split
train_mask = np.zeros((N,), dtype=bool)
train_mask[train_idx] = True
val_mask = np.zeros((N,), dtype=bool)
val_mask[val_idx] = True
test_mask = np.zeros((N,), dtype=bool)
test_mask[test_idx] = True

print("All Data Distribution: \n{}".format(Counter(labels)))
num_classes = len(set(labels))
print('\nNumber of classes: ', num_classes)
def encode_label(labels):
    # integer-encode the string labels, then one-hot encode them
    label_encoder = LabelEncoder()
    labels = label_encoder.fit_transform(labels)
    labels = to_categorical(labels)
    return labels, label_encoder.classes_

labels_encoded, classes = encode_label(labels)
# build the citation graph
G = nx.Graph()
G.add_nodes_from(nodes)
G.add_edges_from(edge_list)

# obtain the adjacency matrix (A); rows/columns follow the node order of X
A = nx.adjacency_matrix(G)
print('Graph info: ', G)  # nx.info() was removed in NetworkX 3.0
# hyperparameters (matching the Cora setup in the original GCN paper)
channels = 16             # channels in the first GCN layer
dropout = 0.5             # dropout rate for the node features
L2_regularization = 5e-4  # L2 regularization rate
learning_rate = 1e-2
epochs = 200
es_patience = 10          # early-stopping patience

# normalize the adjacency matrix for the GCN layers
A = GCNConv.preprocess(A).astype('f4')
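# GCNConv.preprocess applies the "renormalization trick" from Kipf & Welling
# (2017): A_hat = D^-1/2 (A + I) D^-1/2, where D is the degree matrix of
# A + I. A rough manual equivalent (a sketch, assuming a SciPy sparse A and
# `import scipy.sparse as sp`) would be:
#
#   A_tilde = A + sp.eye(N)
#   d = np.asarray(A_tilde.sum(axis=1)).flatten()
#   D_inv_sqrt = sp.diags(d ** -0.5)
#   A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt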
# model definition: two GCN layers over the whole graph
X_in = Input(shape=(F, ))            # node features
fltr_in = Input((N, ), sparse=True)  # preprocessed adjacency matrix

dropout_1 = Dropout(dropout)(X_in)
graph_conv_1 = GCNConv(channels,
                       activation='relu',
                       kernel_regularizer=l2(L2_regularization),
                       use_bias=False)([dropout_1, fltr_in])
dropout_2 = Dropout(dropout)(graph_conv_1)
graph_conv_2 = GCNConv(num_classes,
                       activation='softmax',
                       use_bias=False)([dropout_2, fltr_in])
model = Model(inputs=[X_in, fltr_in], outputs=graph_conv_2)
optimizer = Adam(learning_rate=learning_rate)  # `lr` is deprecated in TF2
model.compile(optimizer=optimizer,
              loss='categorical_crossentropy',
              weighted_metrics=['acc'])
model.summary()
tbCallBack_GCN = TensorBoard(log_dir='./Tensorboard_GCN_cora')
# Train the model full-batch: the whole graph is a single batch, so
# batch_size=N and shuffle=False; the masks weight the loss and metrics so
# that only the relevant split's nodes contribute.
validation_data = ([X, A], labels_encoded, val_mask)
model.fit([X, A],
          labels_encoded,
          sample_weight=train_mask,
          epochs=epochs,
          batch_size=N,
          validation_data=validation_data,
          shuffle=False,
          callbacks=[
              EarlyStopping(patience=es_patience, restore_best_weights=True),
              tbCallBack_GCN
          ])