@afiaka87
Created January 18, 2022 06:22
Finetune GLIDE on a captioned-image dataset, e.g. COCO/LAION
# https://wandb.ai/afiaka87/glide_finetune/runs/3fj69lfc?workspace=user-afiaka87
import os

import bitsandbytes as bnb
import torch as th
from IPython.display import clear_output, display
from ipywidgets import Output
from PIL import Image
from tqdm import tqdm

import wandb
from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
)
from loader import TextImageDataset
# %%
has_cuda = th.cuda.is_available()
fp16 = False  # fp16 destabilizes this finetune, perhaps due to the low batch size / high noise schedule?
device = th.device('cpu' if not has_cuda else 'cuda')
# %%
# Create base model.
options = model_and_diffusion_defaults()
options['cache_text_emb'] = False
# options['use_checkpoint'] = True
options['use_fp16'] = has_cuda and fp16
options['dropout'] = 0.1
options['timestep_respacing'] = '100'  # use 100 diffusion steps for fast sampling
model, diffusion = create_model_and_diffusion(**options)
model.train()
model.requires_grad_(True)
# model.transformer.requires_grad_(True)  # alternative: finetune only the text transformer
if has_cuda and fp16:
    model.convert_to_fp16()
model.to(device)
model.load_state_dict(load_checkpoint('base', device))
print('total base parameters', sum(x.numel() for x in model.parameters() if x.requires_grad))
print(f'transformer params: {sum(x.numel() for x in model.transformer.parameters() if x.requires_grad)}')
# %%
def show_images(batch: th.Tensor):
    """Display a batch of images inline as one horizontal strip."""
    scaled = ((batch + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()
    reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
    display(Image.fromarray(reshaped.numpy()))
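# Quick sanity check for show_images (illustrative only, not part of the training run):
# it expects an [N, 3, H, W] tensor in the model's [-1, 1] range and renders the
# batch side by side. Uncomment to try with random noise:
# show_images(th.clamp(th.randn(2, 3, 64, 64), -1, 1))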
# %%
batch_size = 1
grad_acc = 4
guidance_scale = 3.0
learning_rate = 1e-6
side_x = 64
side_y = 64
upsample_x = 4
base_dir = './finetune_checkpoints'
os.makedirs(base_dir, exist_ok=True)
dataset = TextImageDataset(
    folder="/home/samsepiol/DatasetWorkspace/CurrentDatasets",
    shuffle=True,
    batch_size=batch_size,
    device=device,
)
assert len(dataset) > 0, "Dataset is empty"
print(f"Dataset contains {len(dataset)} images")
def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """Gather values from a 1-D numpy array at the given timesteps and
    broadcast them to `broadcast_shape` (as in openai/guided-diffusion)."""
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)
dataloader = th.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
print(f"Dataset has {len(dataloader)} batches")
def prompt_to_model_kwargs(prompt: str = '', _batch_size: int = 1, device: str = 'cpu'):
    """Tokenize a caption and stack it with the empty (unconditional) prompt,
    as needed for classifier-free guidance."""
    prompt = prompt.lower()
    assert len(prompt) > 0, 'prompt must be a non-empty string'
    tokens = model.tokenizer.encode(prompt)
    tokens, mask = model.tokenizer.padded_tokens_and_mask(tokens, options['text_ctx'])
    uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask([], options['text_ctx'])
    return dict(
        tokens=th.tensor(
            [tokens] * _batch_size + [uncond_tokens] * _batch_size,
            device=device,
        ),
        mask=th.tensor(
            [mask] * _batch_size + [uncond_mask] * _batch_size,
            dtype=th.bool,
            device=device,
        ),
    )
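# Example (illustrative): with _batch_size=1 this returns tensors of shape
# [2, text_ctx] -- row 0 is the caption, row 1 the empty prompt.
# kwargs = prompt_to_model_kwargs('a corgi wearing a top hat', 1, device)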
optim = bnb.optim.Adam8bit([x for x in model.parameters() if x.requires_grad], lr=learning_rate)  # 8-bit Adam keeps optimizer state memory low
out = Output()
display(out)
losses = []
running_loss = 0.0
full_batch_size = batch_size * 2  # conditional + unconditional halves
config = {
    'batch_size': batch_size,
    'grad_acc': grad_acc,
    'side_x': side_x,
    'side_y': side_y,
    'learning_rate': learning_rate,
}
wandb_run = wandb.init(project="glide_finetune", config=config)
try:
    for i, (captions, images) in tqdm(enumerate(dataloader), total=len(dataloader)):
        images = images.to(device)
        for prompt, x in zip(captions, images):
            # Duplicate the image so the conditional and unconditional halves
            # of the classifier-free guidance batch share the same target.
            x = x.repeat((full_batch_size, 1, 1, 1))
            model_kwargs = prompt_to_model_kwargs(prompt=prompt, _batch_size=batch_size, device=device)
            ts = th.randint(0, 99, (full_batch_size,)).to(device)
            noise_variance = _extract_into_tensor(diffusion.betas, ts, x.shape)
            orig_noise = th.randn_like(x, device=x.device)
            noise = (noise_variance ** 0.5).to(x.device) * orig_noise
            # ts * 10 rescales the 100 respaced steps onto the model's original 1000-step embedding.
            output = model(x + noise, ts * 10, **model_kwargs)
            eps = output[..., :3, :, :]  # first 3 channels are predicted noise; the rest is variance
            loss = th.nn.functional.mse_loss(eps, orig_noise)
            running_loss += loss.item()
            loss.backward()
        if i % 1000 == 0:
            model_dict = {
                'weights': model.state_dict(),
                'optim': optim.state_dict(),
                'options': options,
            }
            th.save(model_dict, os.path.join(base_dir, f'glide-ft-{i}.pt'))
            th.save(model_dict, os.path.join(base_dir, f'glide-ft.pt'))
            print(f'Saved checkpoint {i} to {base_dir}/glide-ft-{i}.pt')
        if i % grad_acc == grad_acc - 1:
            optim.step()
            optim.zero_grad()
            running_loss /= grad_acc
            losses.append(running_loss)
            with out:
                clear_output(wait=True)
                wandb_run.log({"loss": running_loss})
            running_loss = 0.0  # reset, otherwise it accumulates across logging windows
except KeyboardInterrupt:
    print("Interrupted")