Skip to content

Instantly share code, notes, and snippets.

View woshiyyya's full-sized avatar
zzz

Yunxuan Xiao woshiyyya

zzz
View GitHub Profile
@woshiyyya
woshiyyya / adag.py
Last active September 17, 2024 16:39
Use ADAG to train a Llama2-7B model with zero-bubble pipeline parallelism
def generate_zbh1_dag(workers, num_microbatches):
num_workers = len(workers)
num_lead_microbatches = num_workers
with InputNode() as inp:
fwd_queues = [[] for _ in range(num_workers)]
bwd_queues = [[] for _ in range(num_workers)]
# Once a worker's counter reaches 0, it cannot execute another fwd until it
# executes a bwd first.
fwd_counter = [num_lead_microbatches - i for i in range(num_workers)]
import ray
import ray.cluster_utils
from ray.experimental.channel.torch_tensor_type import TorchTensorType
from ray.dag import InputNode, MultiOutputNode
from typing import Optional
from ray.dag.compiled_dag_node import CompiledDAG
from argparse import ArgumentError, ArgumentParser
@woshiyyya
woshiyyya / error.log
Last active September 6, 2024 18:13
zbh1 debug
Traceback (most recent call last):
File "/home/ray/default/skeleton_zb_h1.py", line 106, in <module>
ray.get(dag.execute(1))
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
return fn(*args, **kwargs)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
return func(*args, **kwargs)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 2648, in get
return object_refs.get(timeout=timeout)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/experimental/compiled_dag_ref.py", line 90, in get
@woshiyyya
woshiyyya / dtensor_2d_llama.py
Created September 5, 2024 01:05
TP + FSDP with PyTorch DTensor
import sys
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from log_utils import rank_log, get_logger, verify_min_gpu_count
# ---- GPU check ------------
@woshiyyya
woshiyyya / test.py
Created August 22, 2024 21:22
ADAG hides the actual method error stack trace and prints a timeout error instead
import ray
from ray.air.util.torch_dist import _init_torch_distributed
from ray.air._internal.util import find_free_port
from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
@woshiyyya
woshiyyya / channel_error.py
Created August 20, 2024 00:28
DAG NCCL channel error when binding with a node of the same actor
import ray
import torch
from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType
@ray.remote(num_gpus=1)
class MyActor:
def __init__(self):
pass
@woshiyyya
woshiyyya / train_tp_timeout.py
Created August 19, 2024 23:54
DistMM DAG Timeout Failure
import ray
from ray.air.util.torch_dist import _init_torch_distributed
from ray.air._internal.util import find_free_port
from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
@woshiyyya
woshiyyya / benchmark_adag.py
Last active August 30, 2024 23:57
Benchmark NCCL Data Transfer
import ray
import torch
from ray.experimental.channel.torch_tensor_type import TorchTensorType
# shape = (4, 8192)
shape = (4, 24576)
@ray.remote(num_gpus=1)
class MyActor:
@woshiyyya
woshiyyya / test_case_1.py
Last active July 18, 2024 22:47
Train Dashboard BugBash
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig, RunConfig
@woshiyyya
woshiyyya / PP.md
Last active August 1, 2024 21:09
Pseudocode for pipeline parallelism (PP) with Ray DAG