Skip to content

Instantly share code, notes, and snippets.

@nousr
Created June 25, 2022 01:28
Show Gist options
  • Save nousr/fcbbb72ddcfc3e12e4ce895a6b4865f1 to your computer and use it in GitHub Desktop.
Save nousr/fcbbb72ddcfc3e12e4ce895a6b4865f1 to your computer and use it in GitHub Desktop.
import os
import numpy as np
import json
import click
from clip import tokenize
from dalle2_pytorch.trainer import DiffusionPriorTrainer
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
def get_prior(path, device):
    """Build the inference-time diffusion prior and load trained weights into it.

    A ``DiffusionPriorNetwork`` is wrapped in a ``DiffusionPrior`` that uses the
    OpenAI CLIP ``ViT-L/14`` adapter and conditions on text encodings. That is
    wrapped in turn in a ``DiffusionPriorTrainer`` with EMA enabled. Every
    hyperparameter is fixed here.

    The checkpoint at ``path`` is then loaded into the trainer. The trainer and
    both of its priors (regular and EMA) are switched to eval mode, and
    everything is moved to ``device``.

    NOTE(review): these hyperparameters presumably have to match the ones the
    checkpoint was trained with. Confirm before changing any of them.

    Args:
        path: checkpoint location handed to ``DiffusionPriorTrainer.load``.
        device: target device (e.g. ``"cuda"``), used for every ``.to`` call.

    Returns:
        The loaded ``DiffusionPriorTrainer``.
    """
    # Values shared between the network and the DiffusionPrior wrapper.
    embed_dim = 768
    timesteps = 1000
    dropout = 5e-2

    network = DiffusionPriorNetwork(
        dim=embed_dim,
        depth=24,
        dim_head=64,
        heads=32,
        normformer=True,
        attn_dropout=dropout,
        ff_dropout=dropout,
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        num_timesteps=timesteps,
        ff_mult=4,
    ).to(device)

    prior = DiffusionPrior(
        net=network,
        clip=OpenAIClipAdapter("ViT-L/14"),
        image_embed_dim=embed_dim,
        timesteps=timesteps,
        cond_drop_prob=0.1,
        loss_type="l2",
        condition_on_text_encodings=True,
    ).to(device)

    # NOTE(review): the optimizer-related arguments (lr, wd, max_grad_norm,
    # group_wd_params) are presumably only relevant to training, not to
    # sampling. Confirm.
    trainer = DiffusionPriorTrainer(
        diffusion_prior=prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
        group_wd_params=True,
        use_ema=True,
        device=device,
        accelerator=None,
    )
    trainer.load(path)
    trainer.ema_diffusion_prior.to(device)

    # Put the trainer, then each prior explicitly, into eval mode.
    for module in (trainer, trainer.diffusion_prior, trainer.ema_diffusion_prior):
        module.eval()

    trainer.to(device)
    return trainer
@click.command()
@click.option(
    "--parent-folder", default="prompt_embeddings", help="Name of parent folder"
)
@click.option(
    "--prompt-list", default="prompts.json", help="Path to json list of prompts"
)
@click.option(
    "--n-samples",
    default=9,
    type=click.IntRange(min=1),
    help="Number of samples to generate per prompt",
)
@click.option(
    "--model", default="prior.pth", help="Name of diffusion prior trainer to load"
)
@click.option("--device", default="cuda", help="which device to use for inference")
@click.option(
    "--start-index",
    default=4,
    show_default=True,
    help="Number used for the first prompt's output folder (prompt_XXX)",
)
def main(parent_folder, prompt_list, n_samples, model, device, start_index):
    """Sample diffusion-prior embeddings for every prompt in a JSON list."""
    # Output layout, one subfolder per prompt:
    #   <parent_folder>/prompt_XXX/prompt_XXX.txt        the prompt text
    #   <parent_folder>/prompt_XXX/prompt_XXX_batch.npy  the sampled embeddings
    #
    # NOTE(review): the original hard-coded enumerate(..., start=4), presumably
    # to continue an existing set of prompt folders. The --start-index default
    # keeps that numbering; pass 0 or 1 to number from the first prompt.

    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(prompt_list, "r", encoding="utf-8") as f:
        prompts = json.load(f)

    # Flush so the message is visible while the (slow) model load runs;
    # with end="" it would otherwise sit in the stdout buffer until "done!".
    print(f"loading diffusion prior model [{model}]...", end="", flush=True)
    prior = get_prior(model, device=device)
    print("done!")

    # make embedding folder if it doesn't exist
    os.makedirs(parent_folder, exist_ok=True)

    for i, prompt in enumerate(prompts, start=start_index):
        name = f"prompt_{i:03d}"
        prompt_dir = os.path.join(parent_folder, name)
        os.makedirs(prompt_dir, exist_ok=True)

        # dump the prompt next to its embeddings for traceability
        prompt_path = os.path.join(prompt_dir, f"{name}.txt")
        with open(prompt_path, "w", encoding="utf-8") as f:
            f.write(prompt + "\n")

        # Repeat the tokenized prompt n_samples times along the first dimension
        # so the prior draws n_samples embeddings in a single batch.
        tokens = tokenize(prompt).repeat(n_samples, 1)
        embedding = prior.sample(tokens, cond_scale=1.0)
        np.save(
            os.path.join(prompt_dir, f"{name}_batch.npy"),
            embedding.detach().cpu().numpy(),
        )


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment