@afiaka87 · Created December 31, 2021 02:56
Finetune GLIDE (small, filtered) from OpenAI. WIP.
import argparse
import sys

sys.path.append("./glide-text2im")

import torch as th
from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (create_model_and_diffusion,
                                          model_and_diffusion_defaults)
from guided_diffusion import dist_util, logger
from guided_diffusion.image_text_datasets import load_data
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import add_dict_to_argparser
from guided_diffusion.train_util import TrainLoop


def main():
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()
    _device = dist_util.dev()

    # Create the base (64x64) model and load OpenAI's released checkpoint.
    glide_options = model_and_diffusion_defaults()
    glide_model, diffusion = create_model_and_diffusion(**glide_options)
    glide_model.convert_to_fp16()
    glide_model.to(_device)
    glide_model.load_state_dict(load_checkpoint('base', _device))
    logger.log('total base parameters', sum(x.numel()
                                            for x in glide_model.parameters()))

    schedule_sampler = create_named_schedule_sampler(
        args.schedule_sampler, diffusion)

    logger.log("creating data loader...")
    data = load_latent_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        model=glide_model,
        options=glide_options,
        device=_device,
    )

    logger.log("training...")
    TrainLoop(
        model=glide_model,
        diffusion=diffusion,
        data=data,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
    ).run_loop()


def load_latent_data(model, options, data_dir, batch_size, device):
    data = load_data(
        data_dir=data_dir,
        batch_size=batch_size,
        image_size=64,
        class_cond=False,
    )
    for batch, model_kwargs, text in data:
        # Tokenize the caption and pad it to the model's fixed text context length.
        tokens = model.tokenizer.encode(text[0])
        tokens, mask = model.tokenizer.padded_tokens_and_mask(
            tokens, options['text_ctx'])
        # Empty prompt for the unconditional half (classifier-free guidance style).
        uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
            [], options['text_ctx'])
        # Token IDs feed an embedding lookup, so they must be integer tensors.
        tokens = th.tensor([tokens] * batch_size + [uncond_tokens]
                           * batch_size, device=device, dtype=th.long)
        mask = th.tensor([mask] * batch_size + [uncond_mask]
                         * batch_size, dtype=th.bool, device=device)
        # model_kwargs["xf_proj"] = tokens
        # model_kwargs["xf_out"] = uncond_tokens
        model_kwargs["tokens"] = tokens
        model_kwargs["mask"] = mask
        batch = batch.to(dist_util.dev())
        yield batch, model_kwargs


def create_argparser():
    defaults = dict(
        data_dir="",
        schedule_sampler="uniform",
        lr=1e-4,
        weight_decay=0.0,
        lr_anneal_steps=0,
        batch_size=1,
        microbatch=-1,  # -1 disables microbatches
        ema_rate="0.9999",  # comma-separated list of EMA values
        log_interval=10,
        save_interval=10000,
        resume_checkpoint="",
        use_fp16=True,
        fp16_scale_growth=1e-3,
    )
    defaults.update(model_and_diffusion_defaults())
    defaults['encoder_channels'] = 512
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()
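
As a quick sanity check of the tokenization that load_latent_data performs for each batch, the minimal sketch below encodes a single caption and pads it to the model's text context. It assumes glide-text2im is cloned alongside the script (as in the import block above); the caption string is purely illustrative.

import sys
sys.path.append("./glide-text2im")

from glide_text2im.model_creation import (create_model_and_diffusion,
                                          model_and_diffusion_defaults)

options = model_and_diffusion_defaults()
model, _ = create_model_and_diffusion(**options)

# Encode one caption and pad it to the fixed context length, as in load_latent_data.
tokens = model.tokenizer.encode("a corgi wearing a red bowtie")  # illustrative caption
tokens, mask = model.tokenizer.padded_tokens_and_mask(tokens, options["text_ctx"])

print(len(tokens), options["text_ctx"])  # padded token list has length text_ctx
print(sum(mask))                         # count of real (non-padding) positions

The training script itself is driven entirely by the argparse flags generated from the defaults dict (for example --data_dir, --batch_size, --lr, --save_interval), so it can be launched from the command line without code changes.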