Skip to content

Instantly share code, notes, and snippets.

@l4rz
Created February 25, 2021 18:35
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save l4rz/7040835c3f8266d8b8ea3615a0b49494 to your computer and use it in GitHub Desktop.
Save l4rz/7040835c3f8266d8b8ea3615a0b49494 to your computer and use it in GitHub Desktop.
ALEPH by @advadnoun but for local execution
#
# ALEPH by Advadnoun, https://colab.research.google.com/drive/1Q-TbYvASMPRMXCOQjkxxf72CXYjR_8Vp
# "This is a notebook that uses DALL-E's decoder and CLIP to generate images from text. I will very likely make this better & easier to use in the future."
#
# Rearranged to run locally on a faster GPU.
#
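# Directions:
# 1. Clone https://github.com/openai/DALL-E/ and https://github.com/openai/CLIP
# 2. Copy the `dall_e` and `clip` packages into one directory with this script,
#    and put decoder.pkl (https://cdn.openai.com/dall-e/decoder.pkl) next to it.
# 3. Install torch==1.7.1 and the other dependencies imported below
#    (torchvision, numpy, Pillow, imageio, requests).
# 4. Change `text` below to your prompt.
# 5. Run the script.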
#
# (Example run: loss -6.38 after 4000 iterations with lr=0.5.)
#
import torch
import numpy as np
import torchvision
import torchvision.transforms.functional as TF
import torchvision.transforms as T
import PIL
import random
import imageio
import clip
import torch
import io
import requests
from dall_e import map_pixels, unmap_pixels, load_model
# NOTE: carried over from the notebook, where a bare expression displayed the
# list of available CLIP models. In a script the return value is discarded,
# so this line has no visible effect.
clip.available_models()
def save_img(step, img, pre_scaled=True):
    """Write one image to ``<step>.png`` in the current directory.

    Args:
        step: Converted with ``str`` and used as the output filename stem.
        img: Array-like with the channel axis first; it is transposed with axes
            (1, 2, 0) so the channel axis becomes last before writing. Float
            data is expected in [0, 1] and is clipped to that range and
            converted to uint8.
        pre_scaled: Must be True. The False path relied on a ``scale`` helper
            that does not exist in this script.

    Raises:
        NotImplementedError: If ``pre_scaled`` is False.
    """
    if not pre_scaled:
        # The original called an undefined ``scale(img, 48*4, 32*4)`` here,
        # which could only ever raise NameError; fail with a clear message.
        raise NotImplementedError(
            'save_img(pre_scaled=False) needs a scale() helper, which is not defined')
    img = np.transpose(np.asarray(img), (1, 2, 0))
    if np.issubdtype(img.dtype, np.floating):
        # Convert explicitly instead of relying on imageio's implicit (and
        # warning-emitting) lossy float -> uint8 conversion.
        img = (np.clip(img, 0.0, 1.0) * 255).round().astype(np.uint8)
    imageio.imwrite(str(step) + '.png', img)
def preprocess(img, target_image_size=256):
    """Resize and center-crop a PIL image into a ``map_pixels``-ed tensor.

    NOTE(review): nothing in this script calls this function, and the later
    module-level ``perceptor, preprocess = clip.load(...)`` rebinds the name
    ``preprocess`` to CLIP's transform, shadowing this definition.

    Args:
        img: A PIL image; ``img.size`` is read as (width, height).
        target_image_size: Side length of the square output crop. The original
            read a module-level ``target_image_size`` that is never defined
            (NameError). 256 matches OpenAI's DALL-E usage example; confirm it
            for your use.

    Returns:
        A tensor of shape (1, C, target_image_size, target_image_size) passed
        through ``map_pixels``.

    Raises:
        ValueError: If the shorter image side is smaller than
            ``target_image_size``.
    """
    shortest = min(img.size)
    if shortest < target_image_size:
        raise ValueError(f'min dim for image {shortest} < {target_image_size}')
    ratio = target_image_size / shortest
    # img.size is (width, height) while TF.resize expects (height, width).
    new_size = (round(ratio * img.size[1]), round(ratio * img.size[0]))
    img = TF.resize(img, new_size, interpolation=PIL.Image.LANCZOS)
    img = TF.center_crop(img, output_size=2 * [target_image_size])
    # ToTensor gives (C, H, W); unsqueeze adds the leading batch axis.
    img = torch.unsqueeze(T.ToTensor()(img), 0)
    return map_pixels(img)
class Pars(torch.nn.Module):
    """Learnable logits, relaxed with Gumbel-softmax, fed to the DALL-E decoder.

    All constructor arguments are optional; the defaults reproduce the values
    the original hard-coded.

    Args:
        n_tokens: Number of channels of the logit tensor (originally 8192).
        grid_size: Spatial side length of the logit grid (originally 64).
        tau: Gumbel-softmax temperature (originally 1.4; torch's default is 1).
        device: Device the parameter is moved to (originally always 'cuda').
    """

    def __init__(self, n_tokens=8192, grid_size=64, tau=1.4, device='cuda'):
        super().__init__()
        self.n_tokens = n_tokens
        self.grid_size = grid_size
        self.tau = tau
        # Sampled on the CPU generator and then moved, exactly like the original
        # ``torch.randn(...).cuda()``. Shape (1, n_tokens, grid_size, grid_size).
        self.normu = torch.nn.Parameter(
            torch.randn(1, n_tokens, grid_size, grid_size).to(device))

    def forward(self):
        """Return a Gumbel-softmax sample with the same shape as ``normu``."""
        flat = self.normu.view(1, self.n_tokens, -1)
        # NOTE(review): the softmax is taken over dim=-1, i.e. over the grid
        # positions of each channel rather than over the channels at each
        # position. Kept as-is because the script's settings were tuned with
        # it; confirm this is intended.
        soft = torch.nn.functional.gumbel_softmax(flat, dim=-1, tau=self.tau)
        return soft.view(1, self.n_tokens, self.grid_size, self.grid_size)
def checkin(step, loss):
    """Print progress, then save the images currently decoded from ``lats``.

    Decoding runs without gradient tracking: the global ``model`` is applied to
    ``lats()``, the first three output channels are passed through a sigmoid,
    moved to the CPU as float32, mapped back with ``unmap_pixels``, and every
    batch element is handed to ``save_img`` under the same ``step``.
    """
    print('Step', step, 'loss', loss)
    with torch.no_grad():
        decoded = model(lats())
        rgb = torch.sigmoid(decoded[:, :3])
        batch = unmap_pixels(rgb.cpu().float()).numpy()
        for image in batch:
            save_img(step, image)
def _random_crops(images, cutn, out_size=224):
    """Take ``cutn`` random square crops of ``images``, resized and concatenated on dim 0."""
    # Bounds come from the tensor itself rather than from the globals
    # sideX/sideY, so the crops stay valid if the decoded size changes.
    height, width = images.shape[-2:]
    side = min(height, width)
    crops = []
    for _ in range(cutn):
        size = int(side * torch.zeros(1,).normal_(mean=.8, std=.3).clip(.5, .98))
        # randint's upper bound is exclusive; +1 lets a crop reach the far edge.
        top = torch.randint(0, height - size + 1, ())
        left = torch.randint(0, width - size + 1, ())
        crop = images[:, :, top:top + size, left:left + size]
        crops.append(torch.nn.functional.interpolate(
            crop, (out_size, out_size), mode='bilinear', align_corners=False))
    return torch.cat(crops, 0)


def ascend_txt(cutn=128):
    """Compute the loss terms for the current latents.

    Decodes the global latents ``lats`` with ``model``, keeps the first three
    channels, applies a sigmoid and ``unmap_pixels``, takes ``cutn`` random
    crops, normalizes them with ``nom`` and embeds them with ``perceptor``.

    Args:
        cutn: Number of random crops per step (128; the original comment says
            this "improves quality, was 64").

    Returns:
        ``[lat_l, clip_l]``: ``lat_l`` is a constant 0 placeholder, and
        ``clip_l`` is a 1-element tensor equal to -10 times the mean cosine
        similarity between the text embedding ``t`` and the crop embeddings.
    """
    out = unmap_pixels(torch.sigmoid(model(lats())[:, :3].float()))
    into = nom(_random_crops(out, cutn))
    iii = perceptor.encode_image(into)
    lat_l = 0
    return [lat_l, 10*-torch.cosine_similarity(t, iii).view(-1, 1).T.mean(1)]
def train(i):
    """Run one optimization step on the latents; log and save every 100 steps.

    Args:
        i: Current step number; passed to ``checkin`` as the step.
    """
    lat_l, clip_l = ascend_txt()
    loss = (lat_l + clip_l).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Test the step argument itself. The original tested the global ``itt``,
    # which only worked because the caller happened to pass ``itt``.
    if i % 100 == 0:
        checkin(i, loss.item())
#
# Begin
#
text = "Moscow never sleeps"
print('Text:', text)
print('Loading models and stuff')

# The latents are the only thing the optimizer updates.
lats = Pars().cuda()
mapper = [lats.normu]
optimizer = torch.optim.Adam([{'params': mapper, 'lr': .075}])  # was .1

model = load_model("decoder.pkl", 'cuda')
# Alternatively, load the decoder directly from OpenAI's CDN:
# model = load_model("https://cdn.openai.com/dall-e/decoder.pkl", 'cuda')
# The decoder's weights are not in the optimizer, so optimizer.zero_grad()
# never clears them; freezing them avoids computing and accumulating unused
# gradients on every backward().
model.requires_grad_(False)
print('Generator loaded')

# clip.load also returns CLIP's image transform. It is unused here; binding it
# to its own name keeps it from shadowing the preprocess() function above.
perceptor, clip_preprocess = clip.load('ViT-B/32', jit=True)
perceptor = perceptor.eval()
perceptor.requires_grad_(False)  # gradients only need to reach the latents

im_shape = [512, 512, 3]
sideX, sideY, channels = im_shape

tx = clip.tokenize(text)
# The text embedding is a fixed target, so build it without an autograd graph.
with torch.no_grad():
    t = perceptor.encode_text(tx.cuda())
print('Perceptor loaded')

# Per-channel (mean, std) normalization applied to the crops before encoding.
nom = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
print("Starting")
# ``itt`` stays a module-level name, as in the original counter loop.
for itt in range(10000):
    train(itt)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment