import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torchrec.distributed.model_parallel import (
    DistributedModelParallel,
    get_default_sharders,
)
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.nn.parallel import DistributedDataParallel

class SplitWrapper(torch.nn.Module):
    """Runs the wrapped module and moves its outputs onto the next pipeline device."""

    def __init__(self, mod: torch.nn.Module, next_device: int):
        super().__init__()
        self.mod = mod
        self.next_device = next_device

    def forward(self, *args, **kwargs):
        # see recursive_to example: https://github.com/pytorch/pytorch/blob/8cc7221a6517dca314517a2771aff1ce35c3787e/torch/nn/parallel/distributed.py#L1054-L1100
        return recursive_to(self.mod(*args, **kwargs), self.next_device)
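
# The `recursive_to` used above is only referenced by link; below is a minimal
# sketch of such a helper (an assumption, not PyTorch's internal `_recursive_to`):
# it moves tensors nested in lists, tuples, and dicts onto the target device and
# leaves everything else untouched.
def recursive_to(obj, device):
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, (list, tuple)):
        return type(obj)(recursive_to(o, device) for o in obj)
    if isinstance(obj, dict):
        return {k: recursive_to(v, device) for k, v in obj.items()}
    return obj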
import torch

# Simple module for demonstration
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        l1 = torch.nn.Linear(4, 5)
        l2 = torch.nn.Linear(5, 6)
        self.seq = torch.nn.Sequential(l1, l2)
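
# Brief usage sketch (an illustration, not part of the original file): MyModule
# defines no forward(), so the natural thing to demonstrate is inspecting its
# parameter structure.
m = MyModule()
for name, p in m.named_parameters():
    print(name, tuple(p.shape))
# prints: param (3, 4), seq.0.weight (5, 4), seq.0.bias (5,),
#         seq.1.weight (6, 5), seq.1.bias (6,)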
from datetime import datetime
import os
import torch.multiprocessing as mp
GLOBAL_BATCH_SIZE = 2048
BLOCK_SIZE = 256
def rank0_print(rank, msg):
    # only rank 0 prints, to avoid duplicated logs across workers
    if rank == 0:
        print(msg)
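
# Hedged sketch of how rank0_print is typically used with mp.spawn (the worker
# body and world size below are illustrative assumptions, not the original script):
def _demo_worker(rank, world_size):
    per_rank = GLOBAL_BATCH_SIZE // world_size
    rank0_print(rank, f"{datetime.now()} spawned {world_size} workers, per-rank batch = {per_rank}")

if __name__ == "__main__":
    world_size = 2  # illustrative
    mp.spawn(_demo_worker, args=(world_size,), nprocs=world_size, join=True)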
import argparse
import importlib
import os
import submitit
import uuid
from pathlib import Path
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
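
# Hedged sketch (not the original script) of how these imports usually fit
# together on a SLURM cluster: submitit launches the tasks, each task reads its
# rank from submitit.JobEnvironment(), initializes the process group through a
# file:// rendezvous on a shared filesystem (an assumption here), and wraps a
# placeholder model in DDP.
def train_stub(init_file: Path):
    job_env = submitit.JobEnvironment()
    dist.init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",
        init_method=f"file://{init_file}",
        rank=job_env.global_rank,
        world_size=job_env.num_tasks,
    )
    model = nn.Linear(8, 8)  # placeholder model
    if torch.cuda.is_available():
        torch.cuda.set_device(job_env.local_rank)
        model = model.cuda()
        ddp_model = DDP(model, device_ids=[job_env.local_rank])
    else:
        ddp_model = DDP(model)
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.1)
    # ... training loop elided ...
    dist.destroy_process_group()

def launch_stub():
    shared_dir = Path.home() / ".dist_init"  # assumed to be on a shared filesystem
    shared_dir.mkdir(parents=True, exist_ok=True)
    init_file = shared_dir / f"{uuid.uuid4().hex}_init"
    executor = submitit.AutoExecutor(folder=str(shared_dir / "logs"))
    executor.update_parameters(nodes=1, tasks_per_node=2, timeout_min=10)
    executor.submit(train_stub, init_file)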
import os
# see torchrun tutorial: https://pytorch.org/docs/stable/elastic/run.html
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
import argparse
import os
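
# Hedged sketch of the torchrun pattern referenced in the comment above
# (an illustration, not the original script): torchrun exports RANK, WORLD_SIZE,
# and LOCAL_RANK, so env:// initialization needs no explicit addresses.
# Launch with: torchrun --nproc_per_node=<N> this_script.py
import torch.distributed as dist  # not in the imports above
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", local_rank) if torch.cuda.is_available() else torch.device("cpu")
    model = models.resnet18(weights=None).to(device)  # any torchvision model works here
    ddp_model = DDP(model, device_ids=[local_rank] if device.type == "cuda" else None)
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)
    # ... training loop elided ...
    dist.destroy_process_group()

if __name__ == "__main__":
    main()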
import argparse
import os
import socket
import threading
import subprocess
import time
from typing import Tuple
import torch
import torch.distributed as dist
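
# Hedged sketch (illustrative helpers, not necessarily what the original file
# does): the socket / Tuple imports above are the usual ingredients for picking
# a free TCP rendezvous address before initializing the process group.
def find_free_addr() -> Tuple[str, int]:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))  # port 0 lets the OS pick a free port
        return "127.0.0.1", s.getsockname()[1]

def init_process_group_tcp(rank: int, world_size: int, addr: str, port: int) -> None:
    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://{addr}:{port}",
        rank=rank,
        world_size=world_size,
    )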