pseudo code for PP with Ray DAG
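
This pseudo code sketches pipeline parallelism (PP) with two stages: each PipelineStage actor holds one chunk of the model, the forward and backward passes for two micro-batches are wired into a single compiled Ray DAG, and after each pipeline flush every worker applies one optimizer step.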
import ray
from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode

import torch
import torch.nn as nn


# Each stage is a Ray actor that stores one chunk of the model.
@ray.remote(num_gpus=1)
class PipelineStage:
    def __init__(self, pp_rank) -> None:
        self.pp_rank = pp_rank
        self.model: nn.Module = ...  # A chunk of the model
        self.is_last_stage: bool = ...  # e.g. pp_rank == num_stages - 1
        # Cache activations between forward and backward, keyed by batch id.
        self.input_buffer = dict()
        self.output_buffer = dict()

    def forward_pass(self, batch_id, input_tensor):
        """Do one forward pass with the input activations from stage n - 1.

        Return: the output activations (or the loss, on the last stage).
        """
        output_tensor = self.model(input_tensor)
        self.input_buffer[batch_id] = input_tensor
        self.output_buffer[batch_id] = output_tensor
        if self.is_last_stage:
            return calculate_loss(...)
        else:
            return output_tensor

    def backward_pass(self, batch_id, gradient):
        """Do one backward pass with the gradients from stage n + 1.

        Return: the gradients w.r.t. this stage's inputs, to be sent to
        stage n - 1.
        """
        self.output_buffer[batch_id].backward(gradient)
        # (Assumes input_tensor had requires_grad set, so .grad is populated.)
        output_grad = self.input_buffer[batch_id].grad
        # Release the cached activations for this micro-batch.
        self.input_buffer.pop(batch_id)
        self.output_buffer.pop(batch_id)
        return output_grad

    def optimizer_step(self):
        """Do one optimizer step to update the model with accumulated gradients."""
        ...
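
# A minimal sketch of what the elided optimizer_step() body might look like,
# assuming an optimizer (e.g. torch.optim.AdamW over self.model.parameters())
# was created in __init__ (the gist leaves both out):
#
#     self.optimizer.step()
#     self.optimizer.zero_grad()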


if __name__ == "__main__":
    MAX_TRAINING_STEPS = ...  # total number of training iterations

    # Initialize Ray remote workers, each with 1 GPU.
    workers = [PipelineStage.remote(pp_rank) for pp_rank in range(2)]

    with InputNode() as input_batches:
        batch_1, batch_2 = input_batches[0], input_batches[1]
        worker_1, worker_2 = workers
        # Stage 1 forward passes for both micro-batches.
        dag_1 = worker_1.forward_pass.bind(0, batch_1)
        dag_2 = worker_1.forward_pass.bind(1, batch_2)
        # To use GPU-to-GPU (NCCL) channels, add a type hint to the dag nodes:
        # dag_* = dag_*.with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL))
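        # (Assumption, not in the original gist: TorchTensorType would need to
        # be imported, e.g.
        # `from ray.experimental.channel.torch_tensor_type import TorchTensorType`;
        # the exact path may vary across Ray versions.)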
        # Stage 2 forward passes, consuming stage 1's activations.
        dag_1 = worker_2.forward_pass.bind(0, dag_1)
        dag_2 = worker_2.forward_pass.bind(1, dag_2)
        # Backward passes in reverse stage order.
        dag_1 = worker_2.backward_pass.bind(0, dag_1)
        dag_2 = worker_2.backward_pass.bind(1, dag_2)
        dag_1 = worker_1.backward_pass.bind(0, dag_1)
        dag_2 = worker_1.backward_pass.bind(1, dag_2)
        dag = MultiOutputNode([dag_1, dag_2])

    dag = dag.experimental_compile()

    for step in range(MAX_TRAINING_STEPS):
        # Get input data. We use dummy batches here for simplicity. In a
        # realistic setup, you can either:
        # - pass real data batches from this driver script, or
        # - initialize a dataloader on the pp_rank=0 worker and fetch data
        #   inside `forward_pass()`.
        input_micro_batches = [torch.FloatTensor(...) for _ in range(2)]
        # Execute the DAG once to run the forward/backward passes for all
        # micro-batches.
        dag.execute(*input_micro_batches)
        # Pipeline flush: every micro-batch has finished its forward and
        # backward pass. Do one optimizer step on all workers.
        ray.get([worker.optimizer_step.remote() for worker in workers])
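
Note on synchronization (an assumption about the compiled-graph API, not part of the original gist): `dag.execute()` returns a reference rather than blocking, so one could explicitly wait for the pipeline flush to finish before stepping the optimizers:

    output_refs = dag.execute(*input_micro_batches)
    ray.get(output_refs)  # blocks until both backward chains complete
    ray.get([worker.optimizer_step.remote() for worker in workers])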