{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"from datetime import datetime\n",
"import numpy as np\n",
"from PIL import Image\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"from torch import optim\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from torch.cuda.amp import autocast, GradScaler\n",
"\n",
"from transformers import BertTokenizer, AdamW\n",
"import transformers\n",
"\n",
"from dalle.dalle_pytorch import DiscreteVAE, DALLE\n",
"\n",
"IMAGE_SIZE = 256\n",
"TRAIN_ANNOT_PATH = '<your path>/captions_train2014.json'\n",
"VAL_ANNOT_PATH = '<your path>/captions_val2014.json'\n",
"TRAIN_IMAGE_PATH = '<your path>/train2014'\n",
"VAL_IMAGE_PATH = '<your path>/val2014'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class COCODataset(Dataset):\n",
" def __init__(self, annot_path, image_path, image_size=256):\n",
" self._image_path = image_path\n",
" \n",
" with open(annot_path) as file:\n",
" json_file = json.load(file)\n",
" \n",
" self._image_size = image_size\n",
" self._metadata = json_file['images']\n",
" self._captions = {entry['image_id']: entry for entry in json_file['annotations']}\n",
" \n",
" def __getitem__(self, index):\n",
" metadata = self._metadata[index]\n",
" caption = self._captions[metadata['id']]\n",
" image_path = os.path.join(self._image_path, metadata['file_name'])\n",
" image = Image.open(image_path).convert('RGB')\n",
" image = self._crop_image(image)\n",
" x = np.asarray(image) / 127.5 - 1\n",
" return torch.Tensor(x).permute(2, 0, 1), caption['caption']\n",
" \n",
" def _crop_image(self, image):\n",
" width, height = image.size\n",
" min_length = min(width, height)\n",
" \n",
" # center crop\n",
" left = (width - min_length)/2\n",
" top = (height - min_length)/2\n",
" right = (width + min_length)/2\n",
" bottom = (height + min_length)/2\n",
" image = image.crop((left, top, right, bottom))\n",
" \n",
" # resize\n",
" image = image.resize((self._image_size, self._image_size))\n",
" \n",
" return image\n",
" \n",
" def __len__(self):\n",
" return len(self._metadata)\n",
" \n",
"\n",
"def tensor_to_image(tensor):\n",
" tensor = tensor.permute(1, 2, 0)\n",
" arr = (tensor.numpy() + 1) * 127.5\n",
" arr = arr.clip(0, 255)\n",
" arr = arr.astype(np.uint8)\n",
" return Image.fromarray(arr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 12\n",
"train_dataset = COCODataset(TRAIN_ANNOT_PATH, TRAIN_IMAGE_PATH, image_size=IMAGE_SIZE)\n",
"val_dataset = COCODataset(VAL_ANNOT_PATH, VAL_IMAGE_PATH, image_size=IMAGE_SIZE)\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=4)\n",
"val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=4)"
]
},
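{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check of the data pipeline: pull one training sample and\n",
"# confirm preprocessing yields a (3, IMAGE_SIZE, IMAGE_SIZE) tensor in [-1, 1]\n",
"# together with its caption, then visualize it with tensor_to_image.\n",
"sample_image, sample_caption = train_dataset[0]\n",
"print(sample_image.shape, float(sample_image.min()), float(sample_image.max()))\n",
"print(sample_caption)\n",
"plt.imshow(tensor_to_image(sample_image))\n",
"plt.show()"
]
},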
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train_vae(epoch, model, loader, optimizer, device):\n",
" model.train()\n",
" \n",
" for index, (x, _) in enumerate(loader):\n",
" x = x.to(device)\n",
" optimizer.zero_grad()\n",
" \n",
" loss = model(x, return_loss=True)\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" print(f'Epoch {epoch:2} [{index:4}/{len(loader):4}]: Loss: {loss:.4}')"
]
},
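{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch of a validation pass for the VAE: it reuses the same\n",
"# forward call as train_vae (model(x, return_loss=True)) and averages the\n",
"# reconstruction loss over a loader; val_loader is otherwise unused here.\n",
"@torch.no_grad()\n",
"def validate_vae(model, loader, device):\n",
"    model.eval()\n",
"    total_loss = 0.0\n",
"    for x, _ in loader:\n",
"        x = x.to(device)\n",
"        loss = model(x, return_loss=True)\n",
"        total_loss += loss.item()\n",
"    return total_loss / len(loader)"
]
},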
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"TRAIN_VAE = False\n",
"VAE_SAVE_PATH = 'vae.pt'\n",
"DEVICE = 'cuda'\n",
"EPOCHS = 30\n",
"\n",
"vae = DiscreteVAE(\n",
" image_size = IMAGE_SIZE,\n",
" num_layers = 3, # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)\n",
" num_tokens = 8192, # number of visual tokens. in the paper, they used 8192, but could be smaller for downsized projects\n",
" codebook_dim = 512, # codebook dimension\n",
" hidden_dim = 64, # hidden dimension\n",
" num_resnet_blocks = 1, # number of resnet blocks\n",
" temperature = 0.9, # gumbel softmax temperature, the lower this is, the harder the discretization\n",
" straight_through = False, # straight-through for gumbel softmax. unclear if it is better one way or the other,\n",
" kl_div_loss_weight = 0\n",
")\n",
"\n",
"if TRAIN_VAE:\n",
" vae.to(DEVICE)\n",
" \n",
" optimizer = optim.Adam(vae.parameters(), lr=0.001, weight_decay=0.0)\n",
" scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.98)\n",
" \n",
" for epoch in range(1, EPOCHS+1):\n",
" print(f'Epoch {epoch}/{EPOCHS}'.center(50, '='))\n",
" print(f'Learning Rate: {optimizer.param_groups[0][\"lr\"]: .3}')\n",
" train_vae(epoch, vae, train_loader, optimizer, DEVICE)\n",
" scheduler.step()\n",
" \n",
" torch.save(vae.state_dict(), VAE_SAVE_PATH)\n",
"else:\n",
" vae.load_state_dict(torch.load(VAE_SAVE_PATH, map_location=DEVICE))\n",
" vae.to(DEVICE)"
]
},
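{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: average VAE reconstruction loss on the validation set, using the\n",
"# validate_vae sketch defined above.\n",
"vae_val_loss = validate_vae(vae, val_loader, DEVICE)\n",
"print(f'VAE validation loss: {vae_val_loss:.4f}')"
]
},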
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"vae.eval()\n",
"x, y = val_dataset[0]\n",
"codes = vae.get_codebook_indices(x.unsqueeze(0).cuda())\n",
"generated = vae.decode(codes)[0].detach().cpu()\n",
"plt.imshow(tensor_to_image(x))\n",
"plt.show()\n",
"plt.imshow(tensor_to_image(generated))\n",
"plt.show()\n",
"print(y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train_dalle(epoch, model, loader, optimizer, tokenizer, pad_len, device, scaler, \n",
" example_interval=100, example_text='A man in a red shirt and a red hat is on a motorcycle on a hill side.'): \n",
" start_time = datetime.now()\n",
" accumulated_loss = 0\n",
" for index, (x, y) in enumerate(loader):\n",
" model.train()\n",
" x = x.to(device)\n",
" optimizer.zero_grad()\n",
" tokenized = tokenizer(y, return_tensors='pt', padding='max_length', truncation=True, max_length=pad_len)\n",
" \n",
" with autocast():\n",
" loss = model(\n",
" tokenized['input_ids'].to(device), \n",
" x, \n",
" mask=tokenized['attention_mask'].bool().to(device), \n",
" return_loss=True\n",
" )\n",
"# loss.backward()\n",
"# optimizer.step()\n",
" scaler.scale(loss).backward()\n",
" scaler.step(optimizer)\n",
" scaler.update()\n",
" accumulated_loss += loss.item()\n",
" \n",
" if (index + 1) % example_interval == 0:\n",
" tokenized = tokenizer([example_text], return_tensors='pt', padding='max_length', truncation=True, max_length=pad_len)\n",
" model.eval()\n",
" \n",
" with autocast():\n",
" images = model.generate_images(\n",
" tokenized['input_ids'].to(device), \n",
" mask=tokenized['attention_mask'].bool().to(device)\n",
" )\n",
" \n",
" plt.imshow(tensor_to_image(images[0].detach().cpu()))\n",
" plt.show()\n",
" \n",
" print(f'Epoch {epoch:2} [{index:4}/{len(loader):4}]: Loss: {loss:.4f}, Accumulated Loss: {accumulated_loss:.4f}'\n",
" f'ETA: {str((len(loader) - (index + 1)) / (index + 1) * (datetime.now() - start_time)).split(\".\")[0]}')"
]
},
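{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch of a validation pass for DALL-E, mirroring the forward call\n",
"# used in train_dalle (text tokens, images, attention mask, return_loss=True).\n",
"# Like validate_vae above, it simply averages the loss over a loader.\n",
"@torch.no_grad()\n",
"def validate_dalle(model, loader, tokenizer, pad_len, device):\n",
"    model.eval()\n",
"    total_loss = 0.0\n",
"    for x, y in loader:\n",
"        x = x.to(device)\n",
"        tokenized = tokenizer(y, return_tensors='pt', padding='max_length', truncation=True, max_length=pad_len)\n",
"        with autocast():\n",
"            loss = model(\n",
"                tokenized['input_ids'].to(device),\n",
"                x,\n",
"                mask=tokenized['attention_mask'].bool().to(device),\n",
"                return_loss=True\n",
"            )\n",
"        total_loss += loss.item()\n",
"    return total_loss / len(loader)"
]
},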
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"TRAIN_DALLE = True\n",
"DALLE_SAVE_PATH = 'dalle.pt'\n",
"DEVICE = 'cuda'\n",
"EPOCHS = 30\n",
"TEXT_SEQ_LEN = 128\n",
"\n",
"\n",
"tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
"dalle = DALLE(\n",
" dim = 1024,\n",
" vae = vae, # automatically infer (1) image sequence length and (2) number of image tokens\n",
" num_text_tokens = tokenizer.vocab_size, # vocab size for text\n",
" text_seq_len = TEXT_SEQ_LEN, # text sequence length\n",
" depth = 6, # should aim to be 64\n",
" heads = 8, # attention heads\n",
" dim_head = 64, # attention head dimension\n",
" attn_dropout = 0.1, # attention dropout\n",
" ff_dropout = 0.1 # feedforward dropout\n",
")\n",
"\n",
"if TRAIN_DALLE:\n",
" dalle.to(DEVICE)\n",
" scaler = GradScaler()\n",
" \n",
" optimizer = AdamW(dalle.parameters(), lr=1e-4, weight_decay=0.0)\n",
" scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.98)\n",
" \n",
" for epoch in range(1, EPOCHS+1):\n",
" print(f'Epoch {epoch}/{EPOCHS}'.center(50, '='))\n",
" print(f'Learning Rate: {optimizer.param_groups[0][\"lr\"]: .3}')\n",
" train_dalle(epoch, dalle, train_loader, optimizer, tokenizer, TEXT_SEQ_LEN, DEVICE, scaler, example_interval=200)\n",
" scheduler.step()\n",
" \n",
" torch.save(dalle.state_dict(), DALLE_SAVE_PATH)\n",
"else:\n",
" dalle.load_state_dict(torch.load(DALLE_SAVE_PATH, map_location=DEVICE))\n",
" dalle.to(DEVICE)"
]
},
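{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: average DALL-E loss on the validation set, using the\n",
"# validate_dalle sketch defined above.\n",
"dalle_val_loss = validate_dalle(dalle, val_loader, tokenizer, TEXT_SEQ_LEN, DEVICE)\n",
"print(f'DALL-E validation loss: {dalle_val_loss:.4f}')"
]
},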
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"GEN_NUM = 5\n",
"GEN_BATCH = 8\n",
"for j in range(GEN_NUM):\n",
" example_text = 'A man on a horse'\n",
" tokenized = tokenizer(\n",
" [example_text for _ in range(GEN_BATCH)], \n",
" return_tensors='pt', \n",
" padding='max_length', \n",
" truncation=True, \n",
" max_length=128\n",
" )\n",
"\n",
" dalle.eval()\n",
" with autocast():\n",
" images = dalle.generate_images(\n",
" tokenized['input_ids'].to(DEVICE), \n",
" mask=tokenized['attention_mask'].bool().to(DEVICE),\n",
" filter_thres=0.5\n",
" )\n",
"\n",
" for i in range(GEN_BATCH):\n",
" plt.axis('off')\n",
" plt.imshow(tensor_to_image(images[i].detach().cpu()))\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"len(train_dataset)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}