# Gist by @afiaka87, last active March 15, 2021 20:03
import wandb
from functools import lru_cache
from collections import OrderedDict
import os
import imghdr
import argparse
from pathlib import Path
from PIL import Image
# torch
import torch
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_
# vision imports
import pandas as pd
from torchvision import transforms as T
from torch.utils.data import DataLoader, Dataset
# dalle related classes and utils
from dalle_pytorch import DiscreteVAE, DALLE, OpenAIDiscreteVAE
from dalle_pytorch.simple_tokenizer import tokenize, tokenizer, VOCAB_SIZE
# argument parsing
parser = argparse.ArgumentParser()
#group = parser.add_mutually_exclusive_group(required = False)
#group.add_argument('--vae_path', type = str, help='path to your trained discrete VAE')
#group.add_argument('--dalle_path', type = str, help='path to your partially trained DALL-E')
parser.add_argument('--depth', type = int, default = 64, help='Number of transformer layers. Strive for 64.')
parser.add_argument('--heads', type = int, default = 8, help='Number of attention heads. Default is 8.')
# parser.add_argument('--dim_head', type = float, help="Dimension of each attention head")
parser.add_argument('--learning_rate', type = float, default = 3e-4, help="Learning rate. Default is 3e-4.")
# parser.add_argument('--shuffle', type = bool, help="Whether or not to shuffle your dataset."
# parser.add_argument('--resume', type = int, help="Start over from an existing session (broken currently)"
parser.add_argument('--batch_size', type = int, default = 4, help="Batch size. Tough to go beyond 4 with this.")
parser.add_argument('--grad_clip_norm', type = float, default = 0.5, help="Max norm for gradient clipping. Default is 0.5.")
parser.add_argument('--top_k', type = float, default = 0.9, help="filter_thres passed to generate_images for sampling. Default is 0.9.")
# parser.add_argument('--reversible', type = int, help="Allows you to run at much higher depth at the cost of compute time."
# parser.add_argument('--model_dim', type = int, help="Dimensions of your model.")
# parser.add_argument('--attn_types', choices=['full', 'axial_row', 'axial_col', 'conv_like'])
args = parser.parse_args()
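
# Example invocation (the script filename and flag values are illustrative,
# matching the defaults above):
#   python train_dalle.py --depth 64 --heads 8 --learning_rate 3e-4 \
#       --batch_size 4 --grad_clip_norm 0.5 --top_k 0.9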
# helpers
def exists(val):
    return val is not None
# constants
#VAE_PATH = args.vae_path
#DALLE_PATH = args.dalle_path
PROJECT_NAME = "a100_frac_0_1"
VAE_PATH = None
DALLE_PATH = None #'./dalle.pt'
RESUME = exists(DALLE_PATH)
SHUFFLE = True
EPOCHS = 28
BATCH_SIZE = args.batch_size
LEARNING_RATE = args.learning_rate
GRAD_CLIP_NORM = args.grad_clip_norm
DEPTH = args.depth
HEADS = args.heads
TOP_K = args.top_k
MODEL_DIM = 512
TEXT_SEQ_LEN = 256
DIM_HEAD = 64
REVERSIBLE = True
ATTN_TYPES = ('full',)
IMAGE_TEXT_FOLDER = '/root/oai_images'
all_img_filenames = [image for image in os.listdir(IMAGE_TEXT_FOLDER) if ".png" in image]
# df = pd.read_csv("origithink.csv", delimiter='\t')
df = pd.read_parquet("/root/oai_file_text_pairs.parq")
try:
    df = df.loc[all_img_filenames]
    df = df.sample(frac=0.1)  # shuffle and keep 10% of the caption rows
except KeyError:
    print("Some image filenames were missing from the caption table; continuing unfiltered.")
txt_lookup_dict = df.to_dict(into=OrderedDict)['text']
# text cleaning helper
@lru_cache(maxsize=None)
def clean_text(text):
    # strip tokenizer brace artifacts, then normalize punctuation and spacing;
    # rules are applied in order, so later ones see the output of earlier ones
    for old, new in [("{4", ""), ("{3", ""), ("{2", ""), ("{1", ""), ("{0", ""), ("}", ""),
                     ("{", ""), ("  ", " "), ("  ", " "), (":", ""), (" -", "-"), ("COMMA", ","),
                     (" .", "."), (" ,", ","), (". ", "."), (", ", ",")]:
        text = text.replace(old, new)
    return text
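
# A hypothetical before/after for clean_text. Note that the final (", ", ",")
# rule also removes the space after commas:
#   clean_text("portrait of a dog COMMA oil on canvas")
#   -> "portrait of a dog,oil on canvas"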
# reconstitute vae
if RESUME:
    dalle_path = Path(DALLE_PATH)
    assert dalle_path.exists(), 'DALL-E model file does not exist'
    loaded_obj = torch.load(str(dalle_path))
    dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights']
    if vae_params is not None:
        vae = DiscreteVAE(**vae_params)
    else:
        vae = OpenAIDiscreteVAE()
        # if not args.taming:
        #     vae = OpenAIDiscreteVAE()
        # else:
    dalle_params = dict(
        **dalle_params
    )
    IMAGE_SIZE = 256  # vae_params['image_size'] / 2
else:
    if exists(VAE_PATH):
        vae_path = Path(VAE_PATH)
        assert vae_path.exists(), 'VAE model file does not exist'
        loaded_obj = torch.load(str(vae_path))
        vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']
        vae = DiscreteVAE(**vae_params)
        vae.load_state_dict(weights)
    else:
        print('using pretrained VAE for encoding images to tokens')
        vae_params = None
        vae_klass = OpenAIDiscreteVAE
        # OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
        vae = vae_klass()
    IMAGE_SIZE = vae.image_size
    dalle_params = dict(
        num_text_tokens = VOCAB_SIZE,
        text_seq_len = TEXT_SEQ_LEN,
        dim = MODEL_DIM,
        depth = DEPTH,
        heads = HEADS,
        dim_head = DIM_HEAD,
        reversible = REVERSIBLE,
        attn_types = ATTN_TYPES,
    )
# helpers
def save_model(path):
    save_obj = {
        'hparams': dalle_params,
        'vae_params': vae_params,
        'weights': dalle.state_dict()
    }
    torch.save(save_obj, path)
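
# Sketch of reloading a checkpoint written by save_model (this mirrors the
# RESUME branch near the top of the script):
#   ckpt = torch.load('./dalle.pt')
#   dalle_params, vae_params, weights = ckpt['hparams'], ckpt['vae_params'], ckpt['weights']
#   dalle.load_state_dict(weights)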
# dataset loading
#os.chdir("/root/oai_images")
class TextImagePairs(Dataset):
    def __init__(self, image_text_pairs, folder="images", text_len = 256, image_size = 128):
        super().__init__()
        self._path = Path(folder)
        self.image_text_pairs = image_text_pairs
        self.keys = list(self.image_text_pairs)
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize(image_size),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __getitem__(self, ind):
        image_filename = ""
        try:
            image_filename = self.keys[ind]
            description = clean_text(self.image_text_pairs[image_filename])
            image_full_path = os.path.join(self._path, image_filename)
            if imghdr.what(image_full_path) == 'png':
                image = Image.open(image_full_path)
            else:
                # not actually a png; skip ahead, wrapping so the index stays valid
                return self.__getitem__((ind + 64) % len(self.keys))
            tokenized_text = tokenize(description).squeeze(0)
            mask = tokenized_text != 0
            image_tensor = self.image_transform(image)
            return tokenized_text, image_tensor, mask
        except OSError:
            print(f"Truncated image. Saving to wandb, skipping index by 64 and trying again. filename was: {image_filename} at idx {ind}")
            wandb.save('./dalle.pt')
            return self.__getitem__((ind + 64) % len(self.keys))

    def __len__(self):
        return len(self.keys)
# create dataset and dataloader
def build_dataset(batch_size, shuffle, text_seq_len, image_size):
    text_image_pairs_ds = TextImagePairs(
        txt_lookup_dict,
        text_len = text_seq_len,
        folder = IMAGE_TEXT_FOLDER,
        image_size = image_size
    )
    assert len(text_image_pairs_ds) > 0, 'dataset is empty'
    return DataLoader(text_image_pairs_ds, batch_size=batch_size, shuffle=shuffle, drop_last=True)

dl = build_dataset(BATCH_SIZE, SHUFFLE, TEXT_SEQ_LEN, IMAGE_SIZE)
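
# Quick shape sanity check (assumes dalle_pytorch's tokenizer pads captions to
# a 256-token context, which matches TEXT_SEQ_LEN here):
#   text, images, mask = next(iter(dl))
#   text.shape   -> (BATCH_SIZE, 256)
#   images.shape -> (BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE)
#   mask.shape   -> (BATCH_SIZE, 256)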
# initialize DALL-E
dalle = DALLE(vae = vae, **dalle_params).cuda()
if RESUME:
    dalle.load_state_dict(weights)
# optimizer
opt = Adam(dalle.parameters(), lr = LEARNING_RATE)
# experiment tracker
wandb.init(project = PROJECT_NAME, resume = RESUME)
config = wandb.config
config.depth = DEPTH
config.heads = HEADS
config.dim_head = DIM_HEAD
config.learning_rate = LEARNING_RATE
config.shuffle = SHUFFLE
config.resume = RESUME
config.batch_size = BATCH_SIZE
config.grad_clip_norm = GRAD_CLIP_NORM
config.reversible = REVERSIBLE
config.model_dim = MODEL_DIM
config.attn_types = ATTN_TYPES
wandb.watch(dalle)
# training
def run_epoch(epoch):
    for i, (text, images, mask) in enumerate(dl):
        # guard against a bad batch before moving tensors to the GPU
        if text is None or images is None or mask is None:
            print("Bad batch. Skipping iter if possible.")
            continue
        text, images, mask = map(lambda t: t.cuda(), (text, images, mask))
        loss = dalle(text, images, mask = mask, return_loss = True)
        loss.backward()
        clip_grad_norm_(dalle.parameters(), GRAD_CLIP_NORM)
        opt.step()
        opt.zero_grad()
        log = {}
        if i % 10 == 0:
            print(epoch, i, f'loss - {loss.item()}')
            log = {
                **log,
                'epoch': epoch,
                'iter': i,
                'dalle_loss': loss.item()
            }
            torch.cuda.empty_cache()
        if i % 100 == 0:
            sample_text = text[:1]
            token_list = sample_text.masked_select(sample_text != 0).tolist()
            decoded_text = tokenizer.decode(token_list)
            image = dalle.generate_images(
                text[:1],
                mask = mask[:1],
                filter_thres = TOP_K
            )
            save_model('./dalle.pt')
            torch.cuda.empty_cache()
            log = {
                **log,
                'image': wandb.Image(image, caption = decoded_text)
            }
            if i % 500 == 0:  # save less often to prevent /tmp from filling up
                wandb.save('./dalle.pt')
        if log:
            wandb.log(log)
for epoch in range(EPOCHS):
    try:
        run_epoch(epoch)
    except OSError:
        print("Somehow a single truncated image error bubbled all the way up to the epoch level. Skipping epoch (unfortunately) and saving checkpoint to wandb.")
        wandb.save('./dalle.pt')

save_model('./dalle-final.pt')
wandb.save('./dalle-final.pt')
wandb.finish()
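
# Sampling from the finished model could look something like this (a sketch;
# generate_images is the same call used in run_epoch above, and
# torchvision.utils.save_image would write the result to disk):
#   text = tokenize(['a hypothetical caption']).cuda()
#   images = dalle.generate_images(text, mask = (text != 0), filter_thres = TOP_K)
#   torchvision.utils.save_image(images, 'sample.png', normalize=True)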