This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
t1 = torch.rand((3, 3), requires_grad=True) | |
t2 = torch.rand((3, 3), requires_grad=True) | |
rref = rpc.remote("worker1", torch.add, | |
args=(t1, t2)) | |
ddp_model = DDP(my_model) | |
# Setup optimizer | |
optimizer_params = [rref] | |
for param in ddp_model.parameters(): | |
optimizer_params.append(RRef(param)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
model = torch.nn.parallel.DistributedDataParallel( | |
model, | |
device_ids=[rank], | |
output_device=rank, | |
gradient_as_bucket_view=True | |
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
model = torch.nn.parallel.DistributedDataParallel( | |
model, | |
device_ids=[rank], | |
output_device=rank | |
) | |
with model.join(): | |
for _ in range(5): | |
for inp in inputs: | |
loss = model(inp).sum() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _register_comm_hook( | |
self, | |
state: object, | |
hook: callable): | |
def fp16_compress_hook( | |
process_group: object, | |
bucket: dist._GradBucket): | |
compressed_tensor = |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Python | |
output = torch.where(score_over_threshold, label, unknown_labels) | |
// C++ | |
const auto output = torch::where(score_over_threshold, label, unknown_labels); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def __init__(self) -> None: | |
super().__init__() | |
self.model = self._load_model(pretrained_backbone=True).cuda() | |
self.model = DistributedDataParallel(self.model, device_ids=[torch.cuda.current_device()], | |
output_device=torch.cuda.current_device()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
train_ds <- kmnist_dataset( | |
".", | |
download = TRUE, | |
train = TRUE, | |
transform = transform_to_tensor | |
) | |
test_ds <- kmnist_dataset( | |
".", | |
download = TRUE, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
python -m torch_xla.distributed.xla_dist \ | |
--tpu=$TPU_NAME \ | |
--conda-env=torch-xla-1.6 \ | |
--env ANY_ENV_VAR=VALUE \ | |
-- \ | |
python /path/to/your/code.py --train_arg1 \ | |
--train_arg2 ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Import Libraries | |
import torch | |
import torch.nn as nn | |
import torch.multiprocessing as mp | |
import torch.distributed as dist | |
import torchvision | |
import torchvision.transforms as transforms | |
# Model Definition |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Imports | |
import torch | |
import torch.nn as nn | |
import torchvision | |
import torchvision.transforms as transforms | |
import torch_xla | |
import torch_xla.core.xla_model as xm | |
import torch_xla.distributed.parallel_loader as pl |
NewerOlder