@yzh119
Last active October 12, 2020 07:03
Training GraphSAGE w/ fp16 in DGL.
"""Training graphsage w/ fp16.
Usage:
python train_full.py --gpu 0 --fp16 --dataset
Note that GradScaler is not acitvated because the model successfully converges
without gradient scaling.
DGL's Message Passing APIs are not compatible with fp16 yet, hence we disabled
autocast when calling these APIs (e.g. apply_edges, update_all), see
https://github.com/yzh119/sage-fp16.git
In the default setting, using fp16 saves around 1GB GPU memory (from 4052mb
to 3042mb).
"""
import argparse
import time
import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
class SAGEConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 feat_drop=0.,
                 bias=True,
                 use_fp16=False,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()
        self._in_src_feats, self._in_dst_feats = in_feats, in_feats
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.feat_drop = nn.Dropout(feat_drop)
        self.activation = activation
        self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.use_fp16 = use_fp16
        self.reset_parameters()

    def reset_parameters(self):
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
    def forward(self, graph, feat):
        with graph.local_scope():
            feat_src = feat_dst = self.feat_drop(feat)
            h_self = feat_dst
            graph.srcdata['h'] = feat_src
            if self.use_fp16:
                # DGL's message passing kernels do not support fp16 yet, so
                # leave autocast, cast the features back to fp32, and run
                # update_all in full precision.
                with torch.cuda.amp.autocast(enabled=False):
                    graph.srcdata['h'] = graph.srcdata['h'].float()
                    graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
            else:
                graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
            # GraphSAGE GCN does not require fc_self.
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
            # activation
            if self.activation is not None:
                rst = self.activation(rst)
            # normalization
            if self.norm is not None:
                rst = self.norm(rst)
            return rst
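# Minimal usage sketch for SAGEConv under autocast (assumes a toy graph built
# with dgl.rand_graph and a CUDA device; not part of the training script):
#
#     conv = SAGEConv(16, 32, 'gcn', use_fp16=True).cuda()
#     g_toy = dgl.rand_graph(10, 40).int().to(0)
#     x = torch.randn(10, 16, device='cuda')
#     with torch.cuda.amp.autocast():
#         y = conv(g_toy, x)
#     # The linear layers run in fp16 under autocast, while update_all is kept
#     # in fp32 by the enabled=False block above; y comes out as fp16.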
class GraphSAGE(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 aggregator_type,
                 use_fp16):
        super(GraphSAGE, self).__init__()
        self.layers = nn.ModuleList()
        self.dropout = nn.Dropout(dropout)
        self.activation = activation
        # input layer
        self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type, use_fp16=use_fp16))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type, use_fp16=use_fp16))
        # output layer
        self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type, use_fp16=use_fp16))  # activation None

    def forward(self, graph, inputs):
        h = self.dropout(inputs)
        for l, layer in enumerate(self.layers):
            h = layer(graph, h)
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h
def evaluate(model, graph, features, labels, nid):
    model.eval()
    with torch.no_grad():
        logits = model(graph, features)
        logits = logits[nid]
        labels = labels[nid]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)
def main(args):
    # load and preprocess dataset
    data = load_data(args)
    g = data[0]
    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    in_feats = features.shape[1]
    n_classes = data.num_classes
    n_edges = data.graph.number_of_edges()
    print("""----Data statistics------
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
           train_mask.int().sum().item(),
           val_mask.int().sum().item(),
           test_mask.int().sum().item()))

    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        train_mask = train_mask.cuda()
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()
        print("use cuda:", args.gpu)

    train_nid = train_mask.nonzero().squeeze()
    val_nid = val_mask.nonzero().squeeze()
    test_nid = test_mask.nonzero().squeeze()

    # graph preprocess and calculate normalization factor
    g = dgl.remove_self_loop(g)
    n_edges = g.number_of_edges()
    if cuda:
        g = g.int().to(args.gpu)

    # create GraphSAGE model
    model = GraphSAGE(in_feats,
                      args.n_hidden,
                      n_classes,
                      args.n_layers,
                      F.relu,
                      args.dropout,
                      args.aggregator_type,
                      args.fp16)
    if cuda:
        model.cuda()

    if args.fp16:
        from torch.cuda.amp import GradScaler, autocast

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
    # if args.fp16:
    #     scaler = GradScaler()

    # initialize graph
    dur = []
    for epoch in range(args.n_epochs):
        model.train()
        if epoch >= 3:
            t0 = time.time()
        optimizer.zero_grad()
        # forward (under autocast when fp16 is enabled)
        if args.fp16:
            with autocast():
                logits = model(g, features)
                loss = F.cross_entropy(logits[train_nid], labels[train_nid])
        else:
            logits = model(g, features)
            loss = F.cross_entropy(logits[train_nid], labels[train_nid])

        # if args.fp16:
        #     scaler.scale(loss).backward()
        #     scaler.step(optimizer)
        #     scaler.update()
        # else:
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        acc = evaluate(model, g, features, labels, val_nid)
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
              "ETputs(KTEPS) {:.2f} | mem {:.2f} MB".format(
                  epoch, np.mean(dur), loss.item(), acc,
                  n_edges / np.mean(dur) / 1000,
                  torch.cuda.max_memory_allocated() / 1024 / 1024))

    print()
    acc = evaluate(model, g, features, labels, test_nid)
    print("Test Accuracy {:.4f}".format(acc))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GraphSAGE')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2,
                        help="learning rate")
    parser.add_argument("--fp16", action='store_true',
                        help="use fp16 (mixed-precision) training")
    parser.add_argument("--n-epochs", type=int, default=200,
                        help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=512,
                        help="number of hidden gcn units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden gcn layers")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
                        help="weight for L2 loss")
    parser.add_argument("--aggregator-type", type=str, default="gcn",
                        help="Aggregator type: mean/gcn/pool/lstm")
    args = parser.parse_args()
    print(args)

    main(args)