Code snippets for the forum thread performance-issue-of-back-propagation-in-using-raysgd/3042
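The snippets below omit their imports and the gist-specific breakdown-metric constants. A header along these lines would make them self-contained (each file taking the subset it needs); the EPOCH_BREAKDOWN_* string values are assumptions (only the constant names appear in the gist), and the Ray imports follow the RaySGD API of that era (ray.util.sgd):

import argparse
import logging
import time
from collections import OrderedDict

import numpy as np
import ray
import torch  # the last snippet uses the plain `torch` name
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

import dgl
import dgl.nn as dglnn
from dgl.data import register_data_args
from dgl.distributed import (DistDataLoader, DistGraph, DistGraphServer,
                             load_partition, load_partition_book)

from ray.util.sgd.torch import TorchTrainer, TrainingOperator
from ray.util.sgd.torch.constants import (NUM_STEPS, SCHEDULER_STEP_BATCH,
                                          SCHEDULER_STEP_EPOCH)
from ray.util.sgd.utils import NUM_SAMPLES, AverageMeterCollection

# The Ray-actor snippet at the end also does `import sage`, referring to the
# standalone training file (module name assumed).

logger = logging.getLogger(__name__)

# Gist-specific metric keys; the string values here are assumptions,
# only the constant names appear in the snippets.
EPOCH_BREAKDOWN_EPOCH_TIME = "epoch_time"
EPOCH_BREAKDOWN_GPU_MEM = "gpu_mem_mb"
EPOCH_BREAKDOWN_ACC_SAMPLING = "acc_sampling_time"
EPOCH_BREAKDOWN_ACC_COPY = "acc_copy_time"
EPOCH_BREAKDOWN_ACC_FORWARD = "acc_forward_time"
EPOCH_BREAKDOWN_ACC_BACKWARD = "acc_backward_time"
EPOCH_BREAKDOWN_ACC_UPDATE = "acc_update_time"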
def load_subtensor(g_features, g_labels, seeds, input_nodes, device):
    batch_inputs = g_features[input_nodes].to(device)
    batch_labels = g_labels[seeds].to(device)
    return batch_inputs, batch_labels


class NeighborSampler(object):
    def __init__(self, g, fanouts, sample_neighbors, device):
        self.g = g
        # Cache node features/labels so sample_blocks can slice them.
        # (The original snippet referenced self.g_features/self.g_labels
        # without defining them; this is the likely intent.)
        self.g_features = g.ndata["features"]
        self.g_labels = g.ndata["labels"]
        self.fanouts = fanouts
        self.sample_neighbors = sample_neighbors
        self.device = device

    def sample_blocks(self, seeds):
        breakdown_metrics = {}
        start = time.time()
        seeds = th.LongTensor(np.asarray(seeds))
        blocks = []
        for fanout in self.fanouts:
            # For each seed node, sample ``fanout`` neighbors.
            frontier = self.sample_neighbors(self.g, seeds, fanout, replace=True)
            # Then we compact the frontier into a bipartite graph for message passing.
            block = dgl.to_block(frontier, seeds)
            # Obtain the seed nodes for the next layer.
            seeds = block.srcdata[dgl.NID]
            blocks.insert(0, block)
        breakdown_metrics[EPOCH_BREAKDOWN_ACC_SAMPLING] = time.time() - start

        start = time.time()
        input_nodes = blocks[0].srcdata[dgl.NID]
        seeds = blocks[-1].dstdata[dgl.NID]
        batch_features, batch_labels = load_subtensor(
            self.g_features, self.g_labels, seeds, input_nodes, self.device
        )
        if self.device.type != "cpu":
            blocks = [block.to(self.device, non_blocking=True) for block in blocks]
        breakdown_metrics[EPOCH_BREAKDOWN_ACC_COPY] = time.time() - start
        return blocks, batch_features, batch_labels.long(), breakdown_metrics
class DistSAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, features):
        h = features
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h
def create_model(args, device, data):
    train_nid, val_nid, test_nid, in_feats, n_classes, g = data
    # Define the model (the optimizer is created later, in setup()).
    model = DistSAGE(
        in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout
    )
    model = model.to(device)
    return model
class RayGraphDataLoader(DistDataLoader):
    """Distributed data loader; a thin wrapper around DGL's DistDataLoader."""


def create_data_loaders(args, device, data):
    # Unpack data
    train_nid, val_nid, test_nid, in_feats, n_classes, g = data
    # Create sampler
    sampler = NeighborSampler(
        g,
        [int(fanout) for fanout in args.fan_out.split(",")],
        dgl.distributed.sample_neighbors,
        device,
    )
    # Create DataLoaders for constructing blocks
    train_dataloader = RayGraphDataLoader(
        dataset=train_nid.numpy(),
        batch_size=args.batch_size,
        collate_fn=sampler.sample_blocks,
        shuffle=True,
        drop_last=False,
    )
    val_dataloader = RayGraphDataLoader(
        dataset=val_nid.numpy(),
        batch_size=args.batch_size,
        collate_fn=sampler.sample_blocks,
        shuffle=True,
        drop_last=False,
    )
    return train_dataloader, val_dataloader
def sagemain(ip_config_file, local_partition_file, args, rank):
    dgl.distributed.initialize(ip_config_file)
    pb, _, _, _ = load_partition_book(local_partition_file, rank)
    g = DistGraph(args.graph_name, gpb=pb)
    train_nid = dgl.distributed.node_split(g.ndata["train_mask"], pb, force_even=True)
    val_nid = dgl.distributed.node_split(g.ndata["val_mask"], pb, force_even=True)
    test_nid = dgl.distributed.node_split(g.ndata["test_mask"], pb, force_even=True)
    local_nid = pb.partid2nids(pb.partid).detach().numpy()
    logger.info(
        "part {}, train: {} (local: {}), val: {} (local: {}), test: {} (local: {})".format(
            g.rank(),
            len(train_nid),
            len(np.intersect1d(train_nid.numpy(), local_nid)),
            len(val_nid),
            len(np.intersect1d(val_nid.numpy(), local_nid)),
            len(test_nid),
            len(np.intersect1d(test_nid.numpy(), local_nid)),
        )
    )
    labels = g.ndata["labels"][np.arange(g.number_of_nodes())]
    n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
    logger.info("#labels: {}".format(n_classes))

    # Pack data
    in_feats = g.ndata["features"].shape[1]
    data = train_nid, val_nid, test_nid, in_feats, n_classes, g
    # Define the model
    model = DistSAGE(
        in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout
    )
    return data, model
class SageTrainingOperator(TrainingOperator):
    def setup(self, config):
        args = config["args"]
        import socket
        local_ip = socket.gethostbyname(socket.gethostname())
        rank = config[local_ip]
        address_list = config["address_list"]
        partition_config = config["partition_config"]
        data, model = sagemain(address_list, partition_config, args, rank)

        device = th.device("cpu")
        loss_fcn = nn.CrossEntropyLoss()
        loss_fcn = loss_fcn.to(device)
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
        self.model, self.optimizer, self.criterion = self.register(
            models=model, optimizers=optimizer, criterion=loss_fcn
        )
        tloader, vloader = create_data_loaders(args, device, data)
        self.register_data(train_loader=tloader, validation_loader=vloader)
    def train_epoch(self, iterator, info):
        """Runs one standard training pass over the training dataloader.

        Args:
            iterator (iter): Iterator over the training data for the entire
                epoch. This iterator is expected to be entirely consumed.
            info (dict): Dictionary for information to be used for custom
                training operations.

        Returns:
            A dict of metrics from training.
        """
        if not hasattr(self, "model"):
            raise RuntimeError("Either set self.model in setup function or "
                               "override this method to implement a custom "
                               "training loop.")
        model = self.model
        scheduler = None
        if hasattr(self, "scheduler"):
            scheduler = self.scheduler

        if self.use_tqdm and self.world_rank == 0:
            desc = ""
            if info is not None and "epoch_idx" in info:
                if "num_epochs" in info:
                    desc = f"{info['epoch_idx'] + 1}/{info['num_epochs']}e"
                else:
                    desc = f"{info['epoch_idx'] + 1}e"
            # TODO: Implement len for Dataset?
            total = info[NUM_STEPS]
            if total is None:
                if hasattr(iterator, "__len__"):
                    total = len(iterator)
            _progress_bar = tqdm(
                total=total, desc=desc, unit="batch", leave=False)

        if self.use_gpu:
            th.cuda.reset_peak_memory_stats()

        epoch_breakdown_metrics = {
            EPOCH_BREAKDOWN_EPOCH_TIME: 0,
            EPOCH_BREAKDOWN_GPU_MEM: 0,
            EPOCH_BREAKDOWN_ACC_SAMPLING: 0,
            EPOCH_BREAKDOWN_ACC_COPY: 0,
            EPOCH_BREAKDOWN_ACC_FORWARD: 0,
            EPOCH_BREAKDOWN_ACC_BACKWARD: 0,
            EPOCH_BREAKDOWN_ACC_UPDATE: 0,
        }
        metric_meters = AverageMeterCollection()
        model.train()

        tic = time.time()
        for batch_idx, batch in enumerate(iterator):
            batch_info = {
                "batch_idx": batch_idx,
                "global_step": self.global_step,
            }
            batch_info.update(info)
            metrics, breakdown_metrics = self.train_batch(batch, batch_info=batch_info)
            # Add the step breakdown metrics into epoch breakdown metrics
            epoch_breakdown_metrics = {k: v + breakdown_metrics.get(k, 0)
                                       for k, v in epoch_breakdown_metrics.items()}
            if self.use_tqdm and self.world_rank == 0:
                _progress_bar.n = batch_idx + 1
                postfix = {}
                if "train_loss" in metrics:
                    postfix.update(loss=metrics["train_loss"])
                _progress_bar.set_postfix(postfix)
            if scheduler and self.scheduler_step_freq == SCHEDULER_STEP_BATCH:
                scheduler.step()
            metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
            self.global_step += 1
        if scheduler and self.scheduler_step_freq == SCHEDULER_STEP_EPOCH:
            scheduler.step()
        toc = time.time()

        epoch_breakdown_metrics[EPOCH_BREAKDOWN_EPOCH_TIME] = toc - tic
        gpu_mem_alloc = (
            th.cuda.max_memory_allocated() / (1024 * 1024)
            if self.use_gpu
            else 0
        )
        epoch_breakdown_metrics[EPOCH_BREAKDOWN_GPU_MEM] = gpu_mem_alloc

        stats = metric_meters.summary()
        stats.update(epoch_breakdown_metrics)
        return stats
    def train_batch(self, batch, batch_info):
        """Computes loss and updates the model over one batch.

        Args:
            batch: One item of the training iterator.
            batch_info (dict): Information dict passed in from ``train_epoch``.

        Returns:
            A dictionary of metrics.
        """
        if not hasattr(self, "model"):
            raise RuntimeError("Either set self.model in setup function or "
                               "override this method to implement a custom "
                               "training loop.")
        if not hasattr(self, "optimizer"):
            raise RuntimeError("Either set self.optimizer in setup function "
                               "or override this method to implement a custom "
                               "training loop.")
        if not hasattr(self, "criterion"):
            raise RuntimeError("Either set self.criterion in setup function "
                               "or override this method to implement a custom "
                               "training loop.")
        tic = time.time()
        model = self.model
        optimizer = self.optimizer
        criterion = self.criterion
        # Unpack features into a list to support multiple-input models.
        *features, target, breakdown_metrics = batch

        # Compute output.
        start = time.time()
        with self.timers.record("fwd"):
            if self.use_fp16_native:
                with self._amp.autocast():
                    output = model(*features)
                    loss = criterion(output, target)
            else:
                output = model(*features)
                loss = criterion(output, target)
        breakdown_metrics[EPOCH_BREAKDOWN_ACC_FORWARD] = time.time() - start

        # Compute gradients in a backward pass.
        start = time.time()
        with self.timers.record("grad"):
            optimizer.zero_grad()
            if self.use_fp16_apex:
                with self._amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            elif self.use_fp16_native:
                self._amp_scaler.scale(loss).backward()
            else:
                loss.backward()
        breakdown_metrics[EPOCH_BREAKDOWN_ACC_BACKWARD] = time.time() - start

        # Call step of optimizer to update model params.
        start = time.time()
        with self.timers.record("apply"):
            if self.use_fp16_native:
                self._amp_scaler.step(optimizer)
                self._amp_scaler.update()
            else:
                optimizer.step()
        breakdown_metrics[EPOCH_BREAKDOWN_ACC_UPDATE] = time.time() - start
        toc = time.time()

        return {"train_loss": loss.item(), NUM_SAMPLES: target.size(0),
                "throughput (samples/sec)": target.size(0) / (toc - tic)
                }, breakdown_metrics
    def validate_batch(self, batch, batch_info):
        """Calculates the loss and accuracy over a given batch.

        Returns:
            A dict of metrics.
        """
        if not hasattr(self, "model"):
            raise RuntimeError("Either set self.model in setup function or "
                               "override this method to implement a custom "
                               "training loop.")
        if not hasattr(self, "criterion"):
            raise RuntimeError("Either set self.criterion in setup function "
                               "or override this method to implement a custom "
                               "training loop.")
        model = self.model
        criterion = self.criterion
        # Unpack features into a list to support multiple-input models.
        *features, target, breakdown_metrics = batch

        # Compute output.
        with self.timers.record("eval_fwd"):
            if self.use_fp16_native:
                with self._amp.autocast():
                    output = model(*features)
                    loss = criterion(output, target)
            else:
                output = model(*features)
                loss = criterion(output, target)
            _, predicted = th.max(output.data, 1)

        num_correct = (predicted == target).sum().item()
        num_samples = target.size(0)
        return {
            "val_loss": loss.item(),
            "val_accuracy": num_correct / num_samples,
            NUM_SAMPLES: num_samples
        }
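Each batch the RayGraphDataLoader yields is the 4-tuple returned by sample_blocks above; train_batch and validate_batch rely on Python's starred unpacking to split it. A self-contained illustration of that convention (the stand-in values below are not from the gist):

# Stand-ins for the objects sample_blocks returns; illustrative only.
blocks = ["block_layer0", "block_layer1"]   # list of DGL message-passing blocks
batch_features = "features_tensor"          # a th.Tensor in the real pipeline
batch_labels = "labels_tensor"              # a th.Tensor in the real pipeline
timing = {"sampling_seconds": 0.01}

batch = (blocks, batch_features, batch_labels, timing)
*features, target, breakdown_metrics = batch
assert features == [blocks, batch_features] and target is batch_labels
# So model(*features) inside train_batch is model(blocks, batch_features),
# which lines up with DistSAGE.forward(self, blocks, features).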
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GraphSage")
    # add arguments
    args = parser.parse_args()

    start = time.time()
    config = ...
    scheduler = ...
    target_ips = scheduler.get_ips(args.num_partitions, local_mode=False)
    ip_mapping = {}
    for rank in range(args.num_partitions):
        print("Creating service node at ip {}".format(target_ips[rank]))
        pg = scheduler.create_pg(target_ips[rank])
        config.add_service_node(rank, pg)
        # This mapping lets each trainer know which graph rank it should work with.
        ip_mapping[target_ips[rank]] = rank
        ray.get(write_ip_config.options(placement_group=pg).remote(
            target_ips, config.get_ip_config_fullname()
        ))

    # Prepare data and launch graph service.
    config.list_configurations()
    LaunchDataService(config)
    LaunchGraphService(config)

    # Start distributed training.
    trainer = TorchTrainer(
        training_operator_cls=SageTrainingOperator,
        num_workers=args.num_trainers,
        use_gpu=False,
        config={
            ...
        },
    )
    for epoch in range(args.num_epochs):
        stats = trainer.train()
        print(stats)
        if epoch % args.eval_every == args.eval_every - 1:
            stats = trainer.validate()
            print(stats)
    print("Running time = {}".format(time.time() - start))
    trainer.shutdown()
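The driver above calls write_ip_config, LaunchDataService, and LaunchGraphService, none of which are included in the gist. As a rough sketch of what write_ip_config might do (purely an assumption; the real helper is not shown), it would persist one address per machine so dgl.distributed.initialize can discover the graph servers:

import ray

@ray.remote
def write_ip_config(target_ips, path):
    # Hypothetical helper: write one "ip port" line per machine so that
    # dgl.distributed.initialize(path) can find the servers.
    # The port number is an arbitrary placeholder.
    with open(path, "w") as f:
        for ip in target_ips:
            f.write("{} 30050\n".format(ip))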
def load_subtensor(g_features, g_labels, seeds, input_nodes, device):
    batch_inputs = g_features[input_nodes].to(device)
    batch_labels = g_labels[seeds].to(device)
    return batch_inputs, batch_labels


class NeighborSampler(object):
    def __init__(self, g, fanouts, sample_neighbors, device):
        self.g = g
        self.fanouts = fanouts
        self.sample_neighbors = sample_neighbors
        self.device = device

    def sample_blocks(self, seeds):
        seeds = th.LongTensor(np.asarray(seeds))
        blocks = []
        for fanout in self.fanouts:
            # For each seed node, sample ``fanout`` neighbors.
            frontier = self.sample_neighbors(self.g, seeds, fanout, replace=True)
            # Then we compact the frontier into a bipartite graph for message passing.
            block = dgl.to_block(frontier, seeds)
            # Obtain the seed nodes for the next layer.
            seeds = block.srcdata[dgl.NID]
            blocks.insert(0, block)
        return blocks
class DistSAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, batch_size, device):
        # Layer-wise inference over the full graph; the body was elided
        # in the original snippet.
        ...


def compute_acc(pred, labels):
    labels = labels.long()
    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)


def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
    model.eval()
    with th.no_grad():
        pred = model.inference(g, inputs, batch_size, device)
    model.train()
    return compute_acc(pred[val_nid], labels[val_nid].to(pred.device)), compute_acc(
        pred[test_nid], labels[test_nid].to(pred.device)
    )
def create_model(args, device, data):
    train_nid, val_nid, test_nid, in_feats, n_classes, g, g_features, g_labels = data
    # Define the model (the optimizer is created later, in run()).
    model = DistSAGE(
        in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout
    )
    model = model.to(device)
    return model
def run(args, device, data, model):
    # Unpack data
    train_nid, val_nid, test_nid, in_feats, n_classes, g, g_features, g_labels = data
    # Create sampler
    sampler = NeighborSampler(
        g,
        [int(fanout) for fanout in args.fan_out.split(",")],
        dgl.distributed.sample_neighbors,
        device,
    )
    # Create DataLoader for constructing blocks
    dataloader = DistDataLoader(
        dataset=train_nid.numpy(),
        batch_size=args.batch_size,
        collate_fn=sampler.sample_blocks,
        shuffle=True,
        drop_last=False,
    )
    if not args.standalone:
        model = th.nn.parallel.DistributedDataParallel(model)
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    train_size = th.sum(g.ndata["train_mask"][0 : g.number_of_nodes()])

    # Training loop
    iter_tput = []
    epoch = 0
    for epoch in range(args.num_epochs):
        tic = time.time()

        sample_time = 0
        forward_time = 0
        backward_time = 0
        update_time = 0
        num_seeds = 0
        num_inputs = 0
        start = time.time()
        # Loop over the dataloader to sample the computation dependency graph
        # as a list of blocks.
        step_time = []
        for step, blocks in enumerate(dataloader):
            tic_step = time.time()
            sample_time += tic_step - start

            input_nodes = blocks[0].srcdata[dgl.NID]
            seeds = blocks[-1].dstdata[dgl.NID]
            batch_inputs, batch_labels = load_subtensor(
                g_features, g_labels, seeds, input_nodes, device
            )
            batch_labels = batch_labels.long()
            num_seeds += len(blocks[-1].dstdata[dgl.NID])
            num_inputs += len(blocks[0].srcdata[dgl.NID])
            blocks = [block.to(device) for block in blocks]
            batch_labels = batch_labels.to(device)

            # Compute loss and prediction
            start = time.time()
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, batch_labels)
            forward_end = time.time()
            optimizer.zero_grad()
            loss.backward()
            compute_end = time.time()
            forward_time += forward_end - start
            backward_time += compute_end - forward_end

            optimizer.step()
            update_time += time.time() - compute_end

            step_t = time.time() - tic_step
            step_time.append(step_t)
            iter_tput.append(len(blocks[-1].dstdata[dgl.NID]) / step_t)
            if step % args.log_every == 0:
                acc = compute_acc(batch_pred, batch_labels)
                gpu_mem_alloc = (
                    th.cuda.max_memory_allocated() / 1000000
                    if th.cuda.is_available()
                    else 0
                )
                print(
                    "Part {} | Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
                    "Train Acc {:.4f} | Speed (samples/sec) {:.4f} | "
                    "GPU {:.1f} MB | time {:.3f} s".format(
                        g.rank(),
                        epoch,
                        step,
                        loss.item(),
                        acc.item(),
                        np.mean(iter_tput[3:]),
                        gpu_mem_alloc,
                        np.sum(step_time[-args.log_every:]),
                    )
                )
            start = time.time()

        toc = time.time()
        print(
            "Part {}, Epoch Time(s): {:.4f}, sample+data_copy: {:.4f}, "
            "forward: {:.4f}, backward: {:.4f}, update: {:.4f}, "
            "#seeds: {}, #inputs: {}".format(
                g.rank(),
                toc - tic,
                sample_time,
                forward_time,
                backward_time,
                update_time,
                num_seeds,
                num_inputs,
            )
        )
        epoch += 1

        if epoch % args.eval_every == 0 and epoch != 0:
            print("Doing evaluation...", flush=True)
            start = time.time()
            val_acc, test_acc = evaluate(
                model,
                g,
                g_features,
                g_labels,
                val_nid,
                test_nid,
                args.batch_size_eval,
                device,
            )
            print(
                "Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}".format(
                    g.rank(), val_acc, test_acc, time.time() - start
                )
            )
def main(ip_config, local_partition_file, args, rank):
    if not args.standalone:
        th.distributed.init_process_group(backend="gloo")
    dgl.distributed.initialize(ip_config)
    _, _, _, pb, _, _, _ = load_partition(local_partition_file, rank)
    g = DistGraph("ogb-arxiv", gpb=pb)
    train_nid = dgl.distributed.node_split(g.ndata["train_mask"], pb, force_even=True)
    val_nid = dgl.distributed.node_split(g.ndata["val_mask"], pb, force_even=True)
    test_nid = dgl.distributed.node_split(g.ndata["test_mask"], pb, force_even=True)
    local_nid = pb.partid2nids(pb.partid).detach().numpy()
    print(
        "part {}, train: {} (local: {}), val: {} (local: {}), test: {} (local: {})".format(
            g.rank(),
            len(train_nid),
            len(np.intersect1d(train_nid.numpy(), local_nid)),
            len(val_nid),
            len(np.intersect1d(val_nid.numpy(), local_nid)),
            len(test_nid),
            len(np.intersect1d(test_nid.numpy(), local_nid)),
        )
    )
    device = th.device("cpu")
    labels = g.ndata["labels"][np.arange(g.number_of_nodes())]
    n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
    print("#labels:", n_classes)

    # Pull features/labels for the whole graph into local tensors.
    g_features = g.ndata.pop("features")[np.arange(g.num_nodes())]
    g_labels = g.ndata.pop("labels")[np.arange(g.num_nodes())]
    # Pack data
    in_feats = g_features.shape[1]
    data = train_nid, val_nid, test_nid, in_feats, n_classes, g, g_features, g_labels
    model = create_model(args, device, data)
    return (device, data, model)
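All three training snippets read their hyper-parameters from an args namespace while the parser body itself is elided ("# add arguments"). For orientation, here is a hypothetical parser covering every args.<name> the snippets reference; the default values are illustrative guesses, not taken from the gist:

parser = argparse.ArgumentParser(description="GraphSage")
parser.add_argument("--graph_name", type=str, default="ogb-arxiv")
parser.add_argument("--num_hidden", type=int, default=16)
parser.add_argument("--num_layers", type=int, default=2)
parser.add_argument("--dropout", type=float, default=0.5)
parser.add_argument("--lr", type=float, default=0.003)
parser.add_argument("--fan_out", type=str, default="10,25")  # one fanout per layer
parser.add_argument("--batch_size", type=int, default=1000)
parser.add_argument("--batch_size_eval", type=int, default=100000)
parser.add_argument("--num_epochs", type=int, default=20)
parser.add_argument("--eval_every", type=int, default=5)
parser.add_argument("--log_every", type=int, default=20)
parser.add_argument("--num_trainers", type=int, default=1)
parser.add_argument("--num_partitions", type=int, default=1)
parser.add_argument("--standalone", action="store_true")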
@ray.remote
class GraphServer(object):
    def __init__(self, rank, num_partitions, partition_config, ip_config,
                 num_graph_servers_per_machine, total_num_clients):
        self._rank = rank
        self._num_partitions = num_partitions
        self._partition_config = partition_config
        self._ip_config = ip_config
        self._num_graph_servers_per_machine = num_graph_servers_per_machine
        self._total_num_clients = total_num_clients

    def serve(self):
        assert self._num_partitions >= 1, \
            "The number of partitions should be at least 1."
        g = DistGraphServer(
            self._rank,
            # TODO: Design a mechanism to replace the ip_config file in Ray.
            self._ip_config,
            self._num_graph_servers_per_machine,
            self._total_num_clients,
            self._partition_config,
            # TODO: Double-check that shared memory still works when
            # server/client live in different actors.
            disable_shared_mem=True,
        )
        g.start()
        return "Graph Server " + str(self._rank)


@ray.remote
class Network(object):
    def __init__(self, args, rank, partition_config, ip_config):
        self._args = args
        self._device, self._data, self._model = sage.main(
            ip_config, partition_config, args, rank
        )

    def train(self):
        sage.run(self._args, self._device, self._data, self._model)
        return 0

    def get_weights(self):
        return self._model.state_dict()

    def set_weights(self, weights):
        print("Propagating weights to parallel models.")
        self._model.load_state_dict(weights)

    def save(self):
        torch.save(self._model.state_dict(), "sage_model.pt")
if __name__ == "__main__":
    t_start = time.time()
    ...
    print("Launching GraphServers....")
    svr_list = []
    for rank in range(NUM_PARTITIONS):
        print("Starting graph server %d...." % rank)
        svr = GraphServer.options(
            placement_group=pg_list[rank],
            name="GServer" + str(rank),
            lifetime="detached",
        ).remote(rank, NUM_PARTITIONS, partition_config, ip_config,
                 NUM_SERVERS_PER_MACHINE, TOTAL_NUM_CLIENTS)
        svr_list.append(svr)
        svr.serve.remote()

    parser = argparse.ArgumentParser(description="GCN")
    register_data_args(parser)
    # add arguments
    args = parser.parse_args()

    client_list = []
    for rank in range(NUM_PARTITIONS):
        for count in range(NUM_CLIENTS_PER_MACHINE):
            client = Network.options(placement_group=pg_list[rank]).remote(
                args, rank, partition_config, ip_config
            )
            client_list.append(client)

    for i in range(NUM_TRAIN_ITERATIONS):
        ray.get([client.train.remote() for client in client_list])
        # Get the weights from all clients, which is a list of OrderedDicts.
        weights = ray.get([client.get_weights.remote() for client in client_list])
        averaged_weights = OrderedDict()
        for key in weights[0]:
            # Average over every client's copy. (The original indexed only the
            # first NUM_PARTITIONS entries, which drops clients whenever
            # NUM_CLIENTS_PER_MACHINE > 1.)
            weight_list = [w[key] for w in weights]
            averaged_weight = torch.mean(torch.stack(weight_list), dim=0)
            averaged_weights[key] = averaged_weight
        weight_id = ray.put(averaged_weights)
        [client.set_weights.remote(weight_id) for client in client_list]
        print("Model parameter reduce step ", i)

    print("Killing training clients....")
    for client in client_list:
        ray.kill(client)
    time.sleep(2)
    print("Killing GraphServers....")
    for svr in svr_list:
        ray.kill(svr)
    t_end = time.time()
    print("Running time = " + str(t_end - t_start))