Skip to content

Instantly share code, notes, and snippets.

@jwatte
Last active June 2, 2022 15:15
Show Gist options
  • Save jwatte/c744cace32961d55465f29123c88a779 to your computer and use it in GitHub Desktop.
Save jwatte/c744cace32961d55465f29123c88a779 to your computer and use it in GitHub Desktop.
import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer
from imagen_pytorch.t5 import t5_encode_text, DEFAULT_T5_NAME
from torchvision import transforms
# This doesn't yet format the text embedding tensors right
# TODO: check out https://gist.github.com/Netruk44/38d793e6d04a53cc4d9acbfadbb04a5c
import json
from PIL import Image
import sys
import random
# Shared PIL-image -> tensor converter; load_images() uses it to fill each
# slot of the image buffer.
# NOTE(review): per torchvision docs this yields a (C, H, W) float tensor in
# [0, 1] whose channel count C follows the PIL mode (e.g. 1 for grayscale) —
# confirm callers pass RGB images so C == 3.
img_to_tensor = transforms.ToTensor()
def progress(s):
    """Emit the progress message *s* on standard error, verbatim.

    No newline is appended; callers include their own trailing newline.
    """
    err = sys.stderr
    err.write(s)
# Training configuration.
valortrain = 'val2017'        # COCO split; selects the captions file and image folder
maxtrain = 16                 # samples per batch (capped by the annotation count below)
itersperimgset = 50           # update steps run on each loaded batch
imgsetspercheckpoint = 50     # batches drawn per outer "checkpoint" iteration
maxcheckpoints = 50           # number of outer iterations

infile = 'coco/annotations/captions_%s.json' % valortrain
progress("Loading JSON from %s.\n" % infile)
# JSON is UTF-8; don't depend on the platform's default text encoding.
with open(infile, 'r', encoding='utf-8') as f:
    annot = json.load(f)
numimg = len(annot['images'])
numannot = len(annot['annotations'])
if numannot == 0:
    # Fail early with a clear message instead of training on empty batches.
    raise SystemExit("No annotations found in %s" % infile)
# Never request more samples per batch than there are annotations.
maxtrain = min(maxtrain, numannot)
imglist = annot['images']
# Map COCO image id -> on-disk path of that image.
imgpathbyid = {
    img['id']: 'coco/images/%s/%s' % (valortrain, img['file_name'])
    for img in imglist
}
progress("Allocating %d slots for %d images and %d annotations from %s\n" %
         (maxtrain, numimg, numannot, infile))
# Preallocated GPU buffers that load_images() refills in place for every batch:
# a (maxtrain, 256, 1024) text-embedding buffer, a matching (maxtrain, 256)
# boolean mask, and (maxtrain, 3, 256, 256) images. Allocated directly on the
# device to avoid a CPU allocation followed by a copy.
text_embeds = torch.zeros(maxtrain, 256, 1024, device='cuda')
text_masks = torch.ones(maxtrain, 256, dtype=torch.bool, device='cuda')
images = torch.zeros(maxtrain, 3, 256, 256, device='cuda')
# Settings shared by both UNets in the cascade; each UNet adds its own depth
# and attention layout below. The temporary dict is removed afterwards so it
# does not linger at module scope.
_unet_common = dict(
    dim=32,
    cond_dim=512,
    dim_mults=(1, 2, 4, 8),
)

# First stage of the cascade (paired with the 64px entry of image_sizes).
unet1 = Unet(
    **_unet_common,
    layer_attns=(False, True, True, True),
    num_resnet_blocks=3,
)

# Second, super-resolution stage (paired with the 256px entry of image_sizes).
unet2 = Unet(
    **_unet_common,
    layer_cross_attns=(False, False, False, True),
    layer_attns=(False, False, False, True),
    num_resnet_blocks=(2, 4, 8, 8),
)

del _unet_common

# The Imagen cascade: unets, image_sizes and beta_schedules are parallel tuples,
# one entry per stage.
imagen = Imagen(
    unets=(unet1, unet2),
    image_sizes=(64, 256),
    beta_schedules=('cosine', 'linear'),
    text_encoder_name='t5-large',
    timesteps=1000,
    cond_drop_prob=0.5,
)
imagen = imagen.cuda()

# Trainer wrapper used by the training loop to step each UNet.
trainer = ImagenTrainer(imagen)
def _random_square_crop(img, side):
    """Scale *img* so its shorter edge is *side* px, then crop a random side x side square."""
    width, height = img.size
    # TODO: maybe subsample a little more, for stretching?
    if width > height:
        img = img.resize((int(side * width / height), side), Image.BICUBIC)
    else:
        img = img.resize((side, int(side * height / width)), Image.BICUBIC)
    # randint is inclusive at both ends, so every valid offset can be chosen.
    if img.width > side:
        left = random.randint(0, img.width - side)
        return img.crop((left, 0, left + side, side))
    top = random.randint(0, img.height - side)
    return img.crop((0, top, side, top + side))


def load_images(annots, pathsbyid, tensordim, t_embeds, t_masks, t_images,
                t5_name='t5-large'):
    """Refill the training buffers in place with a random batch of samples.

    Picks ``tensordim`` distinct annotations at random. For each one it loads
    the referenced image, scales it and takes a random square crop whose side
    equals the image buffer's last dimension. It then T5-encodes the caption.

    Args:
        annots: list of caption annotations (dicts with 'caption' and 'image_id').
        pathsbyid: mapping from image id to image file path.
        tensordim: number of samples to load; must not exceed len(annots) or
            the first dimension of the buffers.
        t_embeds: text-embedding buffer indexed [sample, token, feature];
            written in place.
        t_masks: boolean mask buffer indexed [sample, token]; written in place.
        t_images: image buffer indexed [sample, channel, y, x]; written in place.
        t5_name: T5 encoder used for the captions. It must be the same encoder
            the Imagen model is configured with (text_encoder_name='t5-large' in
            this script), otherwise training embeddings would not match the ones
            used at sampling time.

    Raises:
        ValueError: if tensordim is larger than len(annots).
    """
    side = t_images.shape[-1]
    max_tokens = t_embeds.shape[1]
    # random.sample draws distinct items without copying and shuffling the whole list.
    for ix, ann in enumerate(random.sample(annots, tensordim)):
        caption = ann['caption']
        imgpath = pathsbyid[ann['image_id']]
        progress("%d, %s, %s\n" % (ix, imgpath, caption))
        # Close the file handle promptly. Convert to RGB because some COCO
        # images are grayscale and would not fit the 3-channel slot.
        with Image.open(imgpath) as src:
            img = _random_square_crop(src.convert('RGB'), side)
        t_images[ix] = img_to_tensor(img)

        emb, mask = t5_encode_text([caption], name=t5_name)
        # The encoder output is only as long as this caption's token sequence.
        # Copy it into the front of the fixed-length slot, zero the remainder,
        # and mask the padding out so the model does not attend to it.
        n = min(emb.shape[1], max_tokens)
        t_embeds[ix].zero_()
        t_embeds[ix, :n] = emb[0, :n].to(t_embeds.device)
        t_masks[ix] = False
        t_masks[ix, :n] = mask[0, :n].to(t_masks.device)
        del emb, mask
# Feed images into imagen, training each UNet in the cascade. Each outer
# iteration draws `imgsetspercheckpoint` fresh random batches and runs
# `itersperimgset` update steps per batch on every UNet. After each outer
# iteration the Imagen module's weights are saved. Trainer-held state (e.g.
# optimizers) is presumably not part of imagen.state_dict(), so verify before
# relying on these files to resume training.
checkpoint_fmt = 'imagen-checkpoint-%03d.pt'
for cp in range(maxcheckpoints):
    for iset in range(imgsetspercheckpoint):
        progress("checkpoint %d imageset %d iters %d\n" %
                 (cp, iset, itersperimgset))
        load_images(annot['annotations'], imgpathbyid,
                    maxtrain, text_embeds, text_masks, images)
        for step in range(itersperimgset):
            progress("iter %d/%d/%d\n" % (cp, iset, step))
            for unet_number in (1, 2):
                loss = trainer(images, text_embeds=text_embeds,
                               text_masks=text_masks, unet_number=unet_number)
                trainer.update(unet_number=unet_number)
                progress("  unet %d loss %.5f\n" % (unet_number, float(loss)))
    torch.save(imagen.state_dict(), checkpoint_fmt % cp)

# Sample from the trained cascade. Use a separate name so the training image
# buffer is not clobbered.
samples = trainer.sample(texts=[
    'a puppy looking anxiously at a giant donut on the table',
    'the milky way galaxy in the style of monet'
], cond_scale=2.)
print(samples.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment