Last active
March 15, 2021 20:03
-
-
Save afiaka87/850fb3cc48edde8a7ed4cb1ce53b6bd2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import wandb | |
from functools import lru_cache | |
from collections import OrderedDict | |
import os | |
import imghdr | |
import argparse | |
from random import choice | |
from pathlib import Path | |
from os import listdir | |
from PIL import Image | |
# torch | |
import torch | |
from torch.optim import Adam | |
from torch.nn.utils import clip_grad_norm_ | |
# vision imports | |
from PIL import Image | |
import pandas as pd | |
from torchvision import transforms as T | |
from torch.utils.data import DataLoader, Dataset | |
from torchvision.datasets import ImageFolder | |
from torchvision.utils import make_grid, save_image | |
# dalle related classes and utils | |
from dalle_pytorch import DiscreteVAE, DALLE, OpenAIDiscreteVAE | |
from dalle_pytorch.simple_tokenizer import tokenize, tokenizer, VOCAB_SIZE | |
# argument parsing
parser = argparse.ArgumentParser()
#group = parser.add_mutually_exclusive_group(required = False)
#group.add_argument('--vae_path', type = str, help='path to your trained discrete VAE')
#group.add_argument('--dalle_path', type = str, help='path to your partially trained DALL-E')
# FIX: the help strings promised defaults ("Default is 3e-4", etc.) that were
# never actually set, so omitting a flag produced None and crashed much later
# (e.g. Adam(lr=None)).  The documented defaults are now real defaults, and
# --depth, which has no sensible default, fails fast at parse time instead.
parser.add_argument('--depth', type = int, required = True, help='Number of transformer layers. Should strive for this to be 64')
parser.add_argument('--heads', type = int, default = 8, help='Number of attention heads. Default is 8')
# parser.add_argument('--dim_head', type = float, help="Dimensions of your heads?")
parser.add_argument('--learning_rate', type = float, default = 3e-4, help="Learning rate for Adam. Default is 3e-4")
# parser.add_argument('--shuffle', type = bool, help="Whether or not to shuffle your dataset."
# parser.add_argument('--resume', type = int, help="Start over from an existing session (broken currently)"
parser.add_argument('--batch_size', type = int, default = 4, help="Tough to go beyond 4 with this. Default is 4")
parser.add_argument('--grad_clip_norm', type = float, default = 0.5, help="Max gradient norm for clip_grad_norm_. Default is 0.5")
parser.add_argument('--top_k', type = float, default = 0.9, help="Passed to dalle.generate_images as filter_thres. Default is 0.9")
# parser.add_argument('--reversible', type = int, help="Allows you to run at much higher depth at the cost of compute time."
# parser.add_argument('--model_dim', type = int, help="Dimensions of your model.")
# parser.add_argument('--attn_types', choices=['full', 'axial_row', 'axial_col', 'conv_like'])
args = parser.parse_args()
# helpers
def exists(val):
    """Report whether *val* carries an actual value (anything except None)."""
    return not (val is None)
# constants
#VAE_PATH = args.vae_path
#DALLE_PATH = args.dalle_path
PROJECT_NAME = "a100_frac_0_1"
VAE_PATH = None
DALLE_PATH = None  # './dalle.pt'
RESUME = exists(DALLE_PATH)
SHUFFLE = True
EPOCHS = 28
BATCH_SIZE = args.batch_size
LEARNING_RATE = args.learning_rate
GRAD_CLIP_NORM = args.grad_clip_norm
DEPTH = args.depth
HEADS = args.heads
TOP_K = args.top_k
MODEL_DIM = 512
TEXT_SEQ_LEN = 256
DIM_HEAD = 64
# FIX: was `REVERSIBLE = True,` -- the stray trailing comma made this the
# one-element tuple (True,) instead of the bool True.
REVERSIBLE = True
# FIX: was `('full')`, which is just the string 'full'; a one-element tuple
# needs the comma.
ATTN_TYPES = ('full',)
IMAGE_TEXT_FOLDER = '/root/oai_images'
images = os.listdir(IMAGE_TEXT_FOLDER)
all_img_filenames = [image for image in os.listdir(IMAGE_TEXT_FOLDER) if ".png" in image]
# df = pd.read_csv("origithink.csv", delimiter='\t')
df = pd.read_parquet("/root/oai_file_text_pairs.parq")
try:
    # restrict the dataframe to captions whose image file actually exists
    df = df.loc[all_img_filenames]
    # FIX: pandas DataFrames have no .shuffle() method -- the original raised
    # AttributeError, which the KeyError handler below never caught.
    # sample(frac=0.1) keeps a random 10% of rows, matching the comment intent
    # (and the "frac_0_1" project name).
    df = df.sample(frac=0.1)
except KeyError:
    # some filenames were missing from the index; proceed with the full frame
    print("Telling pandas to stfu")
# filename -> caption text lookup used by the Dataset below
txt_lookup_dict = df.to_dict(into=OrderedDict)['text']
# caption cleanup
@lru_cache(maxsize=None)
def clean_text(text):
    """Normalize a raw caption string for tokenization.

    Strips brace artifacts ("{0}".."{4}", stray braces), restores commas
    encoded as the literal word "COMMA", drops colons, and tightens spacing
    around punctuation.  Substitutions run in a fixed order, since later
    rules depend on the output of earlier ones.  Results are memoized.
    """
    substitutions = (
        ("{4", ""), ("{3", ""), ("{2", ""), ("{1", ""), ("{0", ""),
        ("}", ""), ("{", ""),
        (" ", " "), (" ", " "),
        (":", ""),
        (" -", "-"),
        ("COMMA", ","),
        (" .", "."),
        (" ,", ","),
        (". ", "."),
        (", ", ","),
    )
    for old, new in substitutions:
        text = text.replace(old, new)
    return text
# reconstitute vae
# (indentation reconstructed: the gist rendering stripped leading whitespace;
# nesting follows the upstream dalle-pytorch train_dalle.py layout)
if RESUME:
    # resuming a partially-trained DALL-E: hparams, VAE params and weights
    # all come from the single checkpoint file
    dalle_path = Path(DALLE_PATH)
    assert dalle_path.exists(), 'DALL-E model file does not exist'
    loaded_obj = torch.load(str(dalle_path))
    dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights']
    if vae_params is not None:
        vae = DiscreteVAE(**vae_params)
    else:
        vae = OpenAIDiscreteVAE()
    # if not args.taming:
    #     vae = OpenAIDiscreteVAE()
    # else:
    dalle_params = dict(**dalle_params)
    IMAGE_SIZE = 256  # vae_params['image_size'] / 2
else:
    # fresh run: build the VAE (trained checkpoint or OpenAI pretrained)
    # and assemble DALL-E hyper-parameters from the constants above
    if exists(VAE_PATH):
        vae_path = Path(VAE_PATH)
        assert vae_path.exists(), 'VAE model file does not exist'
        loaded_obj = torch.load(str(vae_path))
        vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']
        vae = DiscreteVAE(**vae_params)
        vae.load_state_dict(weights)
    else:
        print('using pretrained VAE for encoding images to tokens')
        vae_params = None
        vae_klass = OpenAIDiscreteVAE
        # OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
        vae = vae_klass()
    IMAGE_SIZE = vae.image_size
    dalle_params = dict(
        num_text_tokens = VOCAB_SIZE,
        text_seq_len = TEXT_SEQ_LEN,
        dim = MODEL_DIM,
        depth = DEPTH,
        heads = HEADS,
        dim_head = DIM_HEAD,
        reversible = REVERSIBLE,
        attn_types = ATTN_TYPES,
    )
# helpers
def save_model(path):
    """Write a resumable checkpoint (hparams, VAE params, weights) to *path*."""
    torch.save(
        {
            'hparams': dalle_params,
            'vae_params': vae_params,
            'weights': dalle.state_dict(),
        },
        path,
    )
# dataset loading
#os.chdir("/root/oai_images")
class TextImagePairs(Dataset):
    """Dataset of (tokenized caption, image tensor, padding mask) triples.

    *image_text_pairs* maps image filename -> caption text; images are read
    from *folder*.  Non-PNG or truncated files are skipped by retrying at
    index + 64.  FIX: the retry index now wraps with modulo -- the original
    `ind + 64` ran past the end of ``self.keys`` near the tail of the
    dataset and raised an uncaught IndexError.  (If a long run of entries
    is bad this can still loop; acceptable for a training script.)
    """

    def __init__(self, image_text_pairs, folder="images", text_len = 256, image_size = 128):
        super().__init__()
        self._path = Path(folder)
        self.image_text_pairs = image_text_pairs
        self.keys = list(self.image_text_pairs)
        # NOTE: text_len is accepted for API compatibility but unused --
        # tokenize() fixes the sequence length itself.
        # (fixed internal typo: was `image_tranform`)
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize(image_size),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __getitem__(self, ind):
        image_filename = ""
        try:
            image_filename = self.keys[ind]
            description = clean_text(self.image_text_pairs[image_filename])
            image_full_path = os.path.join(self._path, image_filename)
            if imghdr.what(image_full_path) != 'png':
                # not a real PNG on disk: skip ahead, wrapping at the end
                return self.__getitem__((ind + 64) % len(self.keys))
            image = Image.open(image_full_path)
            tokenized_text = tokenize(description).squeeze(0)
            mask = tokenized_text != 0  # non-zero tokens are real text
            image_tensor = self.image_transform(image)
            return tokenized_text, image_tensor, mask
        except OSError:
            # PIL raises OSError for truncated files; checkpoint then retry
            print(f"Truncated JPEG. Saving wandb, skipping index by 64 and trying again. filename was: {image_filename} at idx {ind}")
            wandb.save('./dalle.pt')
            return self.__getitem__((ind + 64) % len(self.keys))

    def __len__(self):
        return len(self.keys)
# create dataset and dataloader
def build_dataset(batch_size, shuffle, text_seq_len, image_size):
    """Build the training DataLoader over the caption lookup dict.

    :param batch_size: samples per batch (incomplete final batch is dropped)
    :param shuffle: whether the DataLoader reshuffles each epoch
    :param text_seq_len: forwarded to TextImagePairs as text_len
    :param image_size: square side length images are resized/cropped to
    :raises AssertionError: if the dataset is empty
    """
    text_image_pairs_ds = TextImagePairs(
        txt_lookup_dict,
        text_len = text_seq_len,
        folder = IMAGE_TEXT_FOLDER,
        # FIX: was the global IMAGE_SIZE, silently ignoring this parameter
        # (the sole call site passes IMAGE_SIZE, so behavior is unchanged)
        image_size = image_size
    )
    assert len(text_image_pairs_ds) > 0, 'dataset is empty'
    return DataLoader(text_image_pairs_ds, batch_size=batch_size, shuffle=shuffle, drop_last=True)
dl = build_dataset(BATCH_SIZE, SHUFFLE, TEXT_SEQ_LEN, IMAGE_SIZE)
# initialize DALL-E
dalle = DALLE(vae = vae, **dalle_params).cuda()
if RESUME:
    dalle.load_state_dict(weights)
# optimizer
opt = Adam(dalle.parameters(), lr = LEARNING_RATE)
# experiment tracker
# FIX: wandb.init() must run before wandb.config is populated -- the original
# wrote to wandb.config first, which current wandb versions reject (config is
# only attached to a run after init).
wandb.init(project = PROJECT_NAME, resume = RESUME)
config = wandb.config
config.depth = DEPTH
config.heads = HEADS
config.dim_head = DIM_HEAD
config.learning_rate = LEARNING_RATE
config.shuffle = SHUFFLE
config.resume = RESUME
config.batch_size = BATCH_SIZE
config.grad_clip_norm = GRAD_CLIP_NORM
config.reversible = REVERSIBLE
config.model_dim = MODEL_DIM
config.attn_types = ATTN_TYPES
wandb.watch(dalle)
# training
def run_epoch(epoch):
    """Run one optimization pass over the dataloader.

    Every 10th step: print and record the loss.  Every 100th step: decode a
    sample caption, generate an image for it, and checkpoint locally.  Every
    500th step: upload the checkpoint to wandb.  One wandb.log per step.
    """
    for i, (text, images, mask) in enumerate(dl):
        # FIX: guard BEFORE the .cuda() map -- the original checked for None
        # after already calling .cuda() on each tensor, so a None batch would
        # have raised AttributeError before the check could fire.
        if text is None or images is None or mask is None:
            print("Bad generations. Skipping iter if possible.")
            continue
        text, images, mask = map(lambda t: t.cuda(), (text, images, mask))
        loss = dalle(text, images, mask = mask, return_loss = True)
        loss.backward()
        clip_grad_norm_(dalle.parameters(), GRAD_CLIP_NORM)
        opt.step()
        opt.zero_grad()
        log = {}
        if i % 10 == 0:
            print(epoch, i, f'loss - {loss.item()}')
            log = {
                **log,
                'epoch': epoch,
                'iter': i,
                'dalle_loss': loss.item()
            }
            torch.cuda.empty_cache()
        if i % 100 == 0:
            sample_text = text[:1]
            token_list = sample_text.masked_select(sample_text != 0).tolist()
            decoded_text = tokenizer.decode(token_list)
            image = dalle.generate_images(
                text[:1],
                mask = mask[:1],
                filter_thres = TOP_K
            )
            save_model('./dalle.pt')
            torch.cuda.empty_cache()
            log = {
                **log,
                'image': wandb.Image(image, caption = decoded_text)
            }
        # FIX: was `if i % 500:`, which is truthy on every step NOT divisible
        # by 500 -- the exact opposite of the stated intent.
        if i % 500 == 0:  # save less often to prevent /tmp from filling up
            wandb.save('./dalle.pt')
        # FIX: log once per iteration; the original called wandb.log both
        # inside the `i % 10` branch and here, double-logging those steps.
        wandb.log(log)
# drive the full training run, then persist the final model
for epoch in range(EPOCHS):
    try:
        run_epoch(epoch)
    except OSError:
        # a truncated-image OSError escaped the dataset-level retry;
        # abandon this epoch but keep the checkpoint safe on wandb
        print("Somehow a single truncated image error bubbled all the way up to the epoch level. Skipping epoch (unfortunately) and saving checkpoint to wandb.")
        wandb.save('./dalle.pt')
save_model('./dalle-final.pt')
wandb.save('./dalle-final.pt')
wandb.finish()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment