Skip to content

Instantly share code, notes, and snippets.

@woshiyyya
Last active July 11, 2024 20:07
Show Gist options
  • Save woshiyyya/d9c371c54d8f0d86ad3ddebc38ae58b7 to your computer and use it in GitHub Desktop.
Save woshiyyya/d9c371c54d8f0d86ad3ddebc38ae58b7 to your computer and use it in GitHub Desktop.
Pseudocode for pipeline parallelism (PP) with Ray DAG
import ray
from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode
import torch
import torch.nn as nn
# Each stage is a ray actor that stores a chunk of model
@ray.remote(num_gpus=1)
class PipelineStage:
    """One pipeline-parallel stage: a Ray actor that owns one chunk of the model.

    Per micro-batch, the stage keeps the input it received and the tensor it
    must back-propagate from (activations, or the loss on the last stage)
    between ``forward_pass`` and ``backward_pass``, keyed by ``batch_id``.
    """

    def __init__(self, pp_rank, num_stages=2) -> None:
        """Create the stage.

        Args:
            pp_rank: Zero-based position of this stage in the pipeline.
            num_stages: Total number of stages. Only used to decide whether
                this is the last stage (the one that computes the loss).
                Defaults to 2.
        """
        self.pp_rank = pp_rank
        # Stored on the instance so forward_pass/backward_pass can branch on
        # it without depending on an undefined global name.
        self.is_last_stage = pp_rank == num_stages - 1
        self.model: nn.Module = ...  # A chunk of Model (placeholder)
        # batch_id -> input tensor of this stage for that micro-batch.
        self.input_buffer = dict()
        # batch_id -> tensor to call .backward() on for that micro-batch.
        self.output_buffer = dict()

    def forward_pass(self, batch_id, input_tensor):
        """Do one forward pass with the input activations from stage n - 1.

        Args:
            batch_id: Key identifying the micro-batch; the same key must be
                passed to the matching ``backward_pass`` call.
            input_tensor: Raw input (first stage) or activations produced by
                the previous stage.

        Returns:
            The output activations, or the loss if this is the last stage.
        """
        if self.pp_rank > 0:
            # `.grad` is only populated on leaf tensors that require grad.
            # Re-wrap incoming activations so backward_pass can read the
            # gradient to send upstream. The first stage's raw input is left
            # untouched (it may not be a floating-point tensor).
            input_tensor = input_tensor.detach().requires_grad_()

        output_tensor = self.model(input_tensor)
        if self.is_last_stage:
            # Placeholder: compute the training loss from `output_tensor`.
            # NOTE(review): backward_pass assumes this returns a scalar tensor.
            output_tensor = calculate_loss(...)

        self.input_buffer[batch_id] = input_tensor
        self.output_buffer[batch_id] = output_tensor
        return output_tensor

    def backward_pass(self, batch_id, gradient):
        """Do one backward pass with the gradients from stage n + 1.

        Args:
            batch_id: Key of a micro-batch previously seen by ``forward_pass``.
            gradient: Gradient w.r.t. this stage's output, as returned by the
                next stage's ``backward_pass``. Ignored on the last stage,
                which back-propagates from its own buffered loss.

        Returns:
            The gradient w.r.t. this stage's input, or ``None`` when the input
            does not track gradients (e.g. raw data on the first stage).

        Raises:
            KeyError: If ``batch_id`` has no buffered forward pass.
        """
        # Pop first so the buffers are released even if backward() raises.
        input_tensor = self.input_buffer.pop(batch_id)
        output_tensor = self.output_buffer.pop(batch_id)

        if self.is_last_stage:
            # The buffered tensor is the loss; it needs no incoming gradient.
            output_tensor.backward()
        else:
            output_tensor.backward(gradient)
        return input_tensor.grad

    def optimizer_step(self):
        """Do one optimizer step to update model with accumulated gradients."""
        ...
if __name__ == "__main__":
    # Driver: build a two-stage, two-micro-batch pipeline as a Ray DAG, then
    # run forward/backward for all micro-batches followed by an optimizer step.
    # NOTE(review): MAX_TRAINING_STEPS is not defined in this snippet; define
    # it before running.

    # Initialize Ray remote workers, each with 1 GPU
    workers = [PipelineStage.remote(pp_rank) for pp_rank in range(2)]
    worker_1, worker_2 = workers

    with InputNode() as input_batches:
        # Index the input node: one positional argument per micro-batch.
        batch_1 = input_batches[0]
        batch_2 = input_batches[1]

        # The bind order below is deliberate: all forwards of a stage are
        # bound before its backwards, which defines the pipeline schedule.
        # NOTE(review): this relies on each actor running its tasks in bind
        # order - confirm against the Ray version in use.
        # The first argument of every call is the batch_id that pairs a
        # forward_pass with its backward_pass on the same actor.
        dag_1 = worker_1.forward_pass.bind(0, batch_1)
        dag_2 = worker_1.forward_pass.bind(1, batch_2)

        # To use GPU channel, add type hint for the following dag nodes
        # dag_* = dag_*.with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL))
        dag_1 = worker_2.forward_pass.bind(0, dag_1)
        dag_2 = worker_2.forward_pass.bind(1, dag_2)

        dag_1 = worker_2.backward_pass.bind(0, dag_1)
        dag_2 = worker_2.backward_pass.bind(1, dag_2)

        dag_1 = worker_1.backward_pass.bind(0, dag_1)
        dag_2 = worker_1.backward_pass.bind(1, dag_2)

        dag = MultiOutputNode([dag_1, dag_2])

    dag = dag.experimental_compile()

    for step in range(MAX_TRAINING_STEPS):
        # Get input data
        # We use dummy batches here for simplicity. In a realistic setup, you can either:
        # - pass real data batches from this driver script
        # - or initialize a dataloader in pp_rank=0 worker and fetch data in the `forward_pass()`
        input_micro_batches = [torch.FloatTensor(...) for _ in range(2)]

        # Execute DAG once to do forward/backward steps for all input batches.
        # Pipeline Flush: block on the results so every micro-batch has
        # finished its backward pass before the optimizer runs.
        ray.get(dag.execute(*input_micro_batches))

        # Do one optimizer step for all workers
        ray.get([worker.optimizer_step.remote() for worker in workers])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment