DALLE-pytorch training script to replicate a DataParallel error: it builds a DiscreteVAE and a DALLE on random token data, wraps the model in nn.DataParallel, and runs a short training loop.
import random

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import AdamW

from dalle_pytorch import DiscreteVAE, DALLE

random.seed(0)

def main():
    device = torch.device("cuda:0")

    # Synthetic data: 1000 samples of text token ids, attention masks,
    # and pre-quantized image token ids.
    train_caption_ids = torch.randint(0, 4096, (1000, 256))
    train_masks = torch.full((1000, 256), 1).bool()
    train_image_ids = torch.randint(0, 1024, (1000, 1024))
    train_img_dataset = TensorDataset(train_caption_ids, train_masks, train_image_ids)
    train_loader = DataLoader(train_img_dataset, shuffle=True, batch_size=4)

    vae = DiscreteVAE(
        image_size=256,
        num_tokens=1024,
        codebook_dim=512,
        num_layers=3
    )

    dalle = DALLE(
        dim=512,
        vae=vae,               # automatically infer (1) image sequence length and (2) number of image tokens
        num_text_tokens=4096,  # vocab size for text
        text_seq_len=256,      # text sequence length
        depth=16,              # should aim to be 64
        heads=8,               # attention heads
        dim_head=64,           # attention head dimension
        attn_dropout=0.1,      # attention dropout
        ff_dropout=0.1         # feedforward dropout
    )

    # Wrapping the model in nn.DataParallel is what triggers the error
    # this script reproduces.
    dalle = nn.DataParallel(dalle).to(device)
    optimizer = AdamW(dalle.parameters(), lr=3e-4)

    dalle.train()
    for caption_ids, mask, image_tokens in train_loader:
        dalle.zero_grad()
        caption_ids = caption_ids.to(device)
        mask = mask.to(device)
        image_tokens = image_tokens.to(device)
        loss = dalle(caption_ids, image_tokens, mask=mask, return_loss=True)
        loss.backward()
        optimizer.step()

if __name__ == '__main__':
    main()
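The original script also imports DistributedDataParallel, torch.multiprocessing, and torch.distributed without using them; if the nn.DataParallel wrapper turns out to be the culprit, switching to DDP is the usual workaround. Below is a minimal single-node sketch, not part of the original gist: the ddp_worker name, the stand-in nn.Linear model, and the rendezvous address/port are illustrative assumptions.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

def ddp_worker(rank, world_size):
    # One process per GPU; NCCL is the standard backend for CUDA tensors.
    dist.init_process_group(
        "nccl",
        init_method="tcp://127.0.0.1:29500",  # assumed single-node rendezvous
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)
    # In the real script, build vae/dalle here exactly as in main();
    # a stand-in module keeps this sketch self-contained.
    model = nn.Linear(512, 512).to(rank)
    model = DDP(model, device_ids=[rank])
    out = model(torch.randn(4, 512, device=rank))
    out.sum().backward()  # gradients are all-reduced across processes
    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(ddp_worker, args=(world_size,), nprocs=world_size)

With DDP each process owns one model replica, so the dataset should be sharded with a DistributedSampler rather than relying on DataParallel's scatter/gather of a single large batch.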