Skip to content

Instantly share code, notes, and snippets.

View wookim3's full-sized avatar

Woo Kim wookim3

  • Facebook AI
View GitHub Profile
# Create two local tensors and ask the remote "worker1" to add them.
# rpc.remote returns immediately with an RRef (remote reference) that
# points at the (t1 + t2) result living on worker1.
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
rref = rpc.remote("worker1", torch.add, args=(t1, t2))

# Wrap the local model so gradients are synchronized across processes.
ddp_model = DDP(my_model)

# Setup optimizer: collect the remote RRef plus an RRef for every local
# DDP parameter, so a distributed optimizer can update both the remote
# result and the local model parameters.
optimizer_params = [rref]
for param in ddp_model.parameters():
    optimizer_params.append(RRef(param))
# Wrap the model in DistributedDataParallel pinned to this rank's device.
# gradient_as_bucket_view=True lets parameter gradients alias the
# communication buckets instead of keeping separate copies.
model = torch.nn.parallel.DistributedDataParallel(
    model, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True
)
# Wrap the model in DistributedDataParallel pinned to this rank's device.
model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[rank],
    output_device=rank,
)

# join() is the context manager for training with uneven numbers of
# inputs per rank: ranks that exhaust their data early shadow the
# collective calls so the remaining ranks do not hang.
with model.join():
    for _ in range(5):
        for inp in inputs:
            loss = model(inp).sum()
def _register_comm_hook(
self,
state: object,
hook: callable):
def fp16_compress_hook(
process_group: object,
bucket: dist._GradBucket):
compressed_tensor =
# Python
output = torch.where(score_over_threshold, label, unknown_labels)
// C++
const auto output = torch::where(score_over_threshold, label, unknown_labels);
def __init__(self) -> None:
    """Load the model onto the current GPU and wrap it in DDP.

    Assumes torch.distributed is already initialized and the CUDA
    device for this process has been selected before construction.
    """
    super().__init__()
    # NOTE(review): _load_model is defined elsewhere in this class —
    # presumably returns an nn.Module with pretrained backbone weights;
    # confirm against the full class definition.
    self.model = self._load_model(pretrained_backbone=True).cuda()
    self.model = DistributedDataParallel(
        self.model,
        device_ids=[torch.cuda.current_device()],
        output_device=torch.cuda.current_device(),
    )
# Load the KMNIST training split into the current directory
# (downloading it if absent), converting images to tensors on access.
train_ds <- kmnist_dataset(
".",
download = TRUE,
train = TRUE,
transform = transform_to_tensor
)
test_ds <- kmnist_dataset(
".",
download = TRUE,
# Launch distributed training across the TPU pod slice $TPU_NAME,
# activating the torch-xla-1.6 conda env and exporting ANY_ENV_VAR on
# every worker. Everything after the bare "--" is the per-worker
# training command and its arguments.
python -m torch_xla.distributed.xla_dist \
--tpu=$TPU_NAME \
--conda-env=torch-xla-1.6 \
--env ANY_ENV_VAR=VALUE \
-- \
python /path/to/your/code.py --train_arg1 \
--train_arg2 ...
# Import Libraries
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist
import torchvision
import torchvision.transforms as transforms
# Model Definition
# Imports
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl