@johnmccain
Created February 12, 2021 05:14
DALLE-pytorch training script to replicate DataParallel error
import pathlib
import os
import json
import datetime
import random
import glob

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import AdamW, Adam
import torchvision
from torchvision import transforms
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import torch.distributed as dist
from tqdm import tqdm

from dalle_pytorch import DiscreteVAE, DALLE

random.seed(0)


def main():
    device = torch.device("cuda:0")

    # Synthetic data: random text token ids, all-ones masks, random image token ids
    train_caption_ids = torch.randint(0, 4096, (1000, 256))
    train_masks = torch.full((1000, 256), 1).bool()
    train_image_ids = torch.randint(0, 1024, (1000, 1024))
    train_img_dataset = TensorDataset(train_caption_ids, train_masks, train_image_ids)
    train_loader = DataLoader(train_img_dataset, shuffle=True, batch_size=4)

    vae = DiscreteVAE(
        image_size=256,
        num_tokens=1024,
        codebook_dim=512,
        num_layers=3
    )

    dalle = DALLE(
        dim = 512,
        vae = vae,                # automatically infer (1) image sequence length and (2) number of image tokens
        num_text_tokens = 4096,   # vocab size for text
        text_seq_len = 256,       # text sequence length
        depth = 16,               # should aim to be 64
        heads = 8,                # attention heads
        dim_head = 64,            # attention head dimension
        attn_dropout = 0.1,       # attention dropout
        ff_dropout = 0.1          # feedforward dropout
    )

    # Wrapping the model in nn.DataParallel is the step that replicates the error
    dalle = nn.DataParallel(dalle).to(device)
    optimizer = AdamW(dalle.parameters(), lr=3e-4)

    for caption_ids, mask, image_tokens in train_loader:
        dalle.train()
        dalle.zero_grad()

        caption_ids = caption_ids.to(device)
        mask = mask.to(device)
        image_tokens = image_tokens.to(device)

        loss = dalle(caption_ids, image_tokens, mask = mask, return_loss = True)
        loss.backward()
        optimizer.step()


if __name__ == '__main__':
    main()
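
For reference, a minimal sketch of the loss handling commonly used under nn.DataParallel: outputs are gathered from each replica along dim 0, so a scalar loss can come back as a vector with one entry per GPU, and reducing it with .mean() keeps the backward() call scalar. The sketch assumes dalle, train_loader, optimizer, and device are defined as in the script above; it shows the general DataParallel pattern and is not claimed to be the fix for the specific error this gist replicates.

# Sketch: same training step, reducing the gathered per-replica losses to a scalar
for caption_ids, mask, image_tokens in train_loader:
    dalle.train()
    dalle.zero_grad()

    caption_ids = caption_ids.to(device)
    mask = mask.to(device)
    image_tokens = image_tokens.to(device)

    loss = dalle(caption_ids, image_tokens, mask = mask, return_loss = True)
    # nn.DataParallel returns one loss per replica; average before backprop
    loss.mean().backward()
    optimizer.step()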