Skip to content

Instantly share code, notes, and snippets.

@robgon-art
Last active July 16, 2022 04:20
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save robgon-art/920ac282f92f5f1ae1946458b0dfeb88 to your computer and use it in GitHub Desktop.
Save robgon-art/920ac282f92f5f1ae1946458b0dfeb88 to your computer and use it in GitHub Desktop.
Create a gradient image with CLIP.
# Copyright © 2022 Robert A. Gonsalves
# Released under CC BY-SA 4.0
# https://creativecommons.org/licenses/by-sa/4.0/
import torchvision.transforms as T
import torch
# ---------------------------------------------------------------------------
# Setup: encode the prompt, define augmentations, and create the trainable
# control points that describe the RGB gradient.
# Relies on names defined elsewhere in the notebook: clip, model, device,
# num_ctrl_ponts, learning_rate, num_augmentations.
# ---------------------------------------------------------------------------
prompt = "penguins skiing down a snowy mountain"
num_steps = 100
# Controls the random spread of the initial control-point values: with the
# formula used for bg_y below they start uniformly in
# [0.5, 0.5 + init_rand_amount / 2].
init_rand_amount = 0.25

# Encode the text prompt once. It stays fixed while the image is optimized,
# so no gradient is needed.
text_input = clip.tokenize(prompt).to(device)
with torch.no_grad():
    text_features = model.encode_text(text_input)

# Random augmentations applied to every rendered image before it is scored.
# The final Normalize uses per-channel mean/std constants (these appear to be
# CLIP's image-preprocessing statistics; confirm if swapping models).
augment_trans = T.Compose([
    T.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
    T.RandomResizedCrop(224, scale=(0.7, 0.9)),
    T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    T.Normalize((0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711)),
])

# The gradient is described by num_ctrl_ponts control points per RGB channel,
# placed evenly over pixel positions 0..223. Only their values (bg_y, shape
# (3, num_ctrl_ponts)) are trained.
bg_x = torch.linspace(0, 223, num_ctrl_ponts).to(device)
bg_y = (0.5 + init_rand_amount / 2.0 * torch.rand(size=(3, num_ctrl_ponts))).to(device)
bg_y.requires_grad = True
# Dense sample positions (one per pixel) at which the control points are
# interpolated.
bg_xs = torch.linspace(0, 223, 224).to(device)
bg_vars = [bg_y]
bgvals = bg_vars[0]
bg_optim = torch.optim.Adam(bg_vars, lr=learning_rate)

# With target = 1, CosineEmbeddingLoss is 1 - cos(image_features, text_features),
# so minimizing it pulls every augmented image toward the prompt.
# The target needs one entry per batch item, i.e. one per augmentation. Size it
# from num_augmentations instead of a hard-coded 32, which broke the loss
# whenever num_augmentations != 32.
loss_fn = torch.nn.CosineEmbeddingLoss()
target = torch.ones(num_augmentations, device=device)
# Run the main optimization loop.
# bg_x and bg_xs never change, so copy them to the CPU once instead of
# re-copying them for every interp() call on every step.
bg_x_cpu = bg_x.cpu()
bg_xs_cpu = bg_xs.cpu()
for t in range(num_steps + 1):
    bg_optim.zero_grad()

    # Interpolate each channel's control points at every pixel position and
    # stack the three channels. Assumes interp returns one value per entry of
    # bg_xs (224 values), giving a (3, 224) tensor; TODO confirm against
    # interp's definition.
    img = torch.vstack([
        interp(bg_x_cpu, bgvals[channel].cpu(), bg_xs_cpu).to(device)
        for channel in range(3)
    ])
    img = img.permute(1, 0)       # (224, 3): one RGB colour per position
    img = img.tile((224, 1, 1))   # (224, 224, 3): repeat that strip 224 times
    img = img.unsqueeze(0)        # (1, 224, 224, 3)
    # (1, 224, 224, 3) -> (1, 3, 224, 224). This is permute(0, 3, 2, 1), not
    # the plain NHWC -> NCHW permute(0, 3, 1, 2). It also swaps the two
    # spatial axes, so the interpolated axis becomes the image height and the
    # colours vary top-to-bottom (a vertical gradient).
    img = img.permute(0, 3, 2, 1)

    # Score a batch of random augmentations of the image against the prompt.
    im_batch = torch.cat([augment_trans(img) for _ in range(num_augmentations)])
    image_features = model.encode_image(im_batch)
    loss = loss_fn(image_features, text_features, target)
    loss.backward()
    bg_optim.step()

    # Every 10 steps, report progress and display the current image.
    if t % 10 == 0:
        print("-" * 10)
        image = img.detach().cpu().numpy()
        # (1, C, H, W) -> (H, W, C), the same orientation CLIP was shown.
        image = np.transpose(image, (0, 2, 3, 1))[0]
        image = np.clip(image * 255, 0, 255).astype(np.uint8)
        image_pil = Image.fromarray(image)
        print('render loss:', loss.item())
        print('iteration:', t)
        # Discard imshow's return value so `img` keeps referring to the
        # rendered image tensor.
        plt.imshow(image_pil)
        plt.axis('off')
        plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment