@crowsonkb
Last active February 20, 2021 15:08
#!/usr/bin/env python3
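"""Trains a small CLIP-style model (jointly trained image and text encoders with a
symmetric contrastive loss) on COCO Captions.

Large contrastive batches are handled by first embedding every microbatch without
gradients, then re-encoding one microbatch at a time with gradients, so the whole
batch never has to be held in memory with activations at once.

Assumes local copies of the COCO 2017 caption annotations and images at the paths
given by TRAIN_ANN/TRAIN_ROOT and VAL_ANN/VAL_ROOT (the '-160' directory names
suggest pre-downscaled images). Takes an optional --seed argument.
"""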
import argparse
from collections import defaultdict
import csv
import math

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils import data
from torchvision import datasets, transforms
from tqdm import tqdm

TRAIN_ANN = 'annotations/captions_train2017.json'
TRAIN_ROOT = 'train2017-160'
VAL_ANN = 'annotations/captions_val2017.json'
VAL_ROOT = 'val2017-160'

BATCH_SIZE = 2500
MICROBATCH_SIZE = 50
PREFIX = 'clip_coco_2'

# SEQ_LEN and BYTES appear to have been precomputed from the captions with
# get_seq_len_and_tokens() below. TOKENS maps each caption byte to a token id;
# unseen bytes and the null padding byte map to 0.
SEQ_LEN = 630
BYTES = [10, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
         52, 53, 54, 55, 56, 57, 58, 59, 61, 62, 63, 64, 91, 92, 93, 95, 96, 97, 98, 99, 100,
         101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,
         118, 119, 120, 121, 122]
TOKENS = defaultdict(lambda: 0, zip(BYTES, range(1, len(BYTES) + 1)))


class ConvBlock(nn.Sequential):
    """A 3x3 convolution followed by a ReLU."""

    def __init__(self, c_in, c_out):
        super().__init__(
            nn.Conv2d(c_in, c_out, 3, padding=1),
            nn.ReLU(inplace=True),
        )


class ImageEncoder(nn.Sequential):
    """A small VGG-style CNN that maps a 3x128x128 image to a 64-dim embedding."""

    def __init__(self):
        super().__init__(
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ConvBlock(3, 64),
            ConvBlock(64, 64),
            nn.MaxPool2d(2),
            ConvBlock(64, 128),
            ConvBlock(128, 128),
            nn.MaxPool2d(2),
            ConvBlock(128, 128),
            ConvBlock(128, 128),
            nn.MaxPool2d(2),
            ConvBlock(128, 128),
            ConvBlock(128, 128),
            nn.MaxPool2d(2),
            ConvBlock(128, 128),
            ConvBlock(128, 128),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d([4, 4]),
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 64),
        )


class TextEncoder(nn.Module):
    """A byte-level transformer encoder that maps a caption to a 64-dim embedding."""

    def __init__(self):
        super().__init__()
        d_model = 256
        self.embed = nn.Embedding(max(TOKENS.values()) + 1, d_model)
        layer = nn.TransformerEncoderLayer(d_model, 4, d_model)
        self.encoder = nn.TransformerEncoder(layer, 6)
        self.proj = nn.Linear(d_model, 64)
        # Fixed sinusoidal positional encoding, shape [SEQ_LEN, d_model].
        pos = torch.arange(SEQ_LEN)
        dim = torch.arange(d_model)
        pos, dim = torch.meshgrid([pos, dim])
        ramp = pos / 10000**(2 * dim / d_model)
        pe = torch.where(dim % 2 == 0, torch.sin(ramp), torch.cos(ramp))
        self.register_buffer('pe', pe)

    def forward(self, input):
        # input: [SEQ_LEN, batch] of token ids; 0 marks the (left) padding.
        mask = (input == 0).T
        embed = self.embed(input) + self.pe[:, None, :]
        # Project the encoder output at the final position (real text, since
        # sequences are left-padded) down to the embedding size.
        return self.proj(self.encoder(embed, src_key_padding_mask=mask)[-1])


class CLIPLoss(nn.Module):
    """Symmetric contrastive (InfoNCE) loss over image/text embeddings."""

    def __init__(self):
        super().__init__()
        # Log of the logit scale (inverse temperature), learned jointly.
        self.t = nn.Parameter(torch.tensor(0.))

    def forward(self, image_embed, text_embed):
        n = image_embed.shape[0]
        image_embed = F.normalize(image_embed)
        text_embed = F.normalize(text_embed)
        logits = image_embed @ text_embed.T * torch.exp(self.t)
        labels = torch.arange(n, device=self.t.device)
        loss_i = F.cross_entropy(logits, labels)
        loss_t = F.cross_entropy(logits.T, labels)
        acc_i = (torch.argmax(logits, dim=1) == labels).sum()
        acc_t = (torch.argmax(logits, dim=0) == labels).sum()
        return (loss_i + loss_t) / 2, (acc_i + acc_t) / n / 2


def get_seq_len_and_tokens(datasets):
    """Finds the longest caption (in bytes) and the sorted set of bytes used,
    for deriving SEQ_LEN and BYTES."""
    seq_len = 0
    unique_bytes = set()
    for dataset in datasets:
        for item in tqdm(dataset):
            seq = ' '.join(item[1]).lower().encode()
            seq_len = max(len(seq), seq_len)
            for b in seq:
                unique_bytes.add(b)
    return seq_len, sorted(unique_bytes)


def collate(samples):
    image_batch = torch.stack([s[0] for s in samples])
    # Join each image's captions, lowercase and encode to bytes, left-pad to
    # SEQ_LEN with null bytes, then map each byte to its token id.
    texts = [' '.join(s[1]).lower().encode() for s in samples]
    texts = [list(text.rjust(SEQ_LEN, b'\0')) for text in texts]
    texts = [[TOKENS[b] for b in text] for text in texts]
    text_batch = torch.tensor(texts).T
    return image_batch, text_batch


def main():
    p = argparse.ArgumentParser()
    p.add_argument('--seed', type=int, default=0, help='the random seed')
    args = p.parse_args()

    torch.manual_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    train_tf = transforms.Compose([
        transforms.Resize([160, 160]),
        transforms.RandomCrop([128, 128]),
        transforms.ToTensor(),
    ])
    val_tf = transforms.Compose([
        transforms.Resize([160, 160]),
        transforms.CenterCrop([128, 128]),
        transforms.ToTensor(),
    ])
    train_set = datasets.CocoCaptions(TRAIN_ROOT, TRAIN_ANN, transform=train_tf)
    train_dl = data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                               num_workers=2, collate_fn=collate, pin_memory=True)
    val_set = datasets.CocoCaptions(VAL_ROOT, VAL_ANN, transform=val_tf)
    val_dl = data.DataLoader(val_set, batch_size=BATCH_SIZE,
                             num_workers=2, collate_fn=collate, pin_memory=True)

    image_enc = ImageEncoder().to(device)
    text_enc = TextEncoder().to(device)
    clip_loss = CLIPLoss().to(device)
    print('Image encoder parameters:', sum(x.numel() for x in image_enc.parameters()))
    print('Text encoder parameters:', sum(x.numel() for x in text_enc.parameters()))
    print('CLIP loss parameters:', sum(x.numel() for x in clip_loss.parameters()))

    params = [*image_enc.parameters(), *text_enc.parameters(), *clip_loss.parameters()]
    opt = optim.AdamW(params, lr=1e-4, weight_decay=0.01)

    epoch = 1
    csvfile = open(PREFIX + '.csv', 'w')
    writer = csv.writer(csvfile)
    writer.writerow(['epoch', 'loss', 'accuracy'])
    csvfile.flush()

    def train():
        image_enc.train()
        text_enc.train()
        clip_loss.train()
        i = 0
        for image_batch, text_batch in tqdm(train_dl):
            i += 1
            image_batch = image_batch.to(device, non_blocking=True)
            text_batch = text_batch.to(device, non_blocking=True)
            n = math.ceil(BATCH_SIZE / MICROBATCH_SIZE)
            image_mbs = torch.chunk(image_batch, n)
            text_mbs = torch.chunk(text_batch, n, dim=1)

            # First pass: embed every microbatch without gradients, so the full
            # contrastive batch never has to be held with activations at once.
            with torch.no_grad():
                images = [image_enc(mb) for mb in image_mbs]
                texts = [text_enc(mb) for mb in text_mbs]
                loss, acc = clip_loss(torch.cat(images), torch.cat(texts))
            tqdm.write(f'{i} {loss.item():g} {acc.item():g}')

            # Second pass: re-encode one microbatch at a time with gradients,
            # keeping the other embeddings cached (detached), and accumulate
            # gradients across all microbatches before stepping.
            opt.zero_grad()
            for j, mb in enumerate(image_mbs):
                images_tmp = images.copy()
                images_tmp[j] = image_enc(mb)
                loss, _ = clip_loss(torch.cat(images_tmp), torch.cat(texts))
                loss.backward()
            for j, mb in enumerate(text_mbs):
                texts_tmp = texts.copy()
                texts_tmp[j] = text_enc(mb)
                loss, _ = clip_loss(torch.cat(images), torch.cat(texts_tmp))
                loss.backward()
            opt.step()

    def val():
        print('Validating...')
        image_enc.eval()
        text_enc.eval()
        clip_loss.eval()
        losses, accs = [], []
        for image_batch, text_batch in tqdm(val_dl):
            image_batch = image_batch.to(device, non_blocking=True)
            text_batch = text_batch.to(device, non_blocking=True)
            n = math.ceil(BATCH_SIZE / MICROBATCH_SIZE)
            with torch.no_grad():
                images = [image_enc(mb) for mb in torch.chunk(image_batch, n)]
                texts = [text_enc(mb) for mb in torch.chunk(text_batch, n, dim=1)]
                loss, acc = clip_loss(torch.cat(images), torch.cat(texts))
            losses.append(loss.item() * len(image_batch))
            accs.append(acc.item() * len(image_batch))
        avg_loss = sum(losses) / len(val_set)
        avg_acc = sum(accs) / len(val_set)
        print(f'Validation loss: {avg_loss:g}, accuracy: {avg_acc:g}')
        writer.writerow([epoch, avg_loss, avg_acc])
        csvfile.flush()

    def save():
        state = {'image_enc': image_enc.state_dict(),
                 'text_enc': text_enc.state_dict(),
                 'clip_loss': clip_loss.state_dict(),
                 'opt': opt.state_dict()}
        torch.save(state, PREFIX + '.pth')
        print(f'Wrote checkpoint to {PREFIX}.pth.')

    try:
        while True:
            print('Epoch', epoch)
            train()
            val()
            save()
            epoch += 1
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()