Skip to content

Instantly share code, notes, and snippets.

@nousr
Created July 4, 2022 18:28
Show Gist options
  • Save nousr/452b8cec7326c8587a03aa6876e9d148 to your computer and use it in GitHub Desktop.
Save nousr/452b8cec7326c8587a03aa6876e9d148 to your computer and use it in GitHub Desktop.
import os
import click
import clip
import numpy as np
from PIL import Image
import torch
import wandb
from dalle2_pytorch import (
DALLE2,
DiffusionPrior,
DiffusionPriorNetwork,
OpenAIClipAdapter,
train_configs,
)
from torchvision.utils import make_grid
from torchvision.transforms.functional import to_pil_image
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.trainer import DiffusionPriorTrainer
from torchvision.transforms.functional import to_pil_image
def load_decoder(decoder_state_dict_path, config_file_path, device):
    """Build a decoder from a training config and load checkpoint weights into it.

    Args:
        decoder_state_dict_path: Path to a file readable by ``torch.load`` that
            holds the decoder's state dict.
        config_file_path: Path to the JSON training config, parsed with
            ``train_configs.TrainDecoderConfig.from_json_path``.
        device: Device the decoder is moved to.

    Returns:
        The decoder, switched to eval mode and placed on ``device``.

    Warns:
        UserWarning: if the checkpoint's keys do not match the decoder's
            (the load is non-strict, so mismatches would otherwise be silent).
    """
    import warnings

    config = train_configs.TrainDecoderConfig.from_json_path(config_file_path)
    decoder = config.decoder.create().to(device)
    # Read the checkpoint onto the CPU so it loads regardless of the device it
    # was saved from; load_state_dict copies the values into decoder's tensors.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    decoder_state_dict = torch.load(decoder_state_dict_path, map_location="cpu")
    # strict=False tolerates key mismatches, but ignoring them silently can leave
    # parts of the decoder at their initial values, so surface what was skipped.
    incompatible = decoder.load_state_dict(decoder_state_dict, strict=False)
    missing = list(incompatible.missing_keys)
    unexpected = list(incompatible.unexpected_keys)
    if missing or unexpected:
        warnings.warn(
            f"Non-strict load of {decoder_state_dict_path!r}: "
            f"{len(missing)} missing key(s) {missing[:5]}, "
            f"{len(unexpected)} unexpected key(s) {unexpected[:5]}",
            stacklevel=2,
        )
    # Drop the CPU copy of the weights now that they live in the decoder.
    del decoder_state_dict
    decoder.eval()
    return decoder
def format_images(
    batch: torch.Tensor,
    nrow: int = 3,
    resize: bool = False,
    size: tuple = (512, 512),
):
    """Tile a batch of images into one grid and convert it to a PIL image.

    Args:
        batch: Images passed straight to ``torchvision.utils.make_grid``
            (presumably float tensors of shape (B, C, H, W) in [0, 1] --
            TODO confirm against the decoder's output).
        nrow: Number of images per grid row.
        resize: If truthy, resize the grid to ``size`` with nearest-neighbour
            resampling; otherwise return it at its native resolution.
        size: Target ``(width, height)`` used when ``resize`` is truthy.

    Returns:
        A ``PIL.Image.Image`` containing the grid.
    """
    grid = make_grid(batch, nrow=nrow)
    # to_pil_image converts float tensors via x * 255 -> uint8, so values
    # outside [0, 1] would wrap around; clamp float grids first.
    if grid.is_floating_point():
        grid = grid.clamp(0, 1)
    image = to_pil_image(grid)
    # Truthiness check: the old `== True` / `== False` pair returned None for
    # any other value (e.g. None).
    if not resize:
        return image
    # Image.Resampling exists only in Pillow >= 9.1; older versions expose the
    # same constant directly on the Image module.
    resampling = getattr(Image, "Resampling", Image)
    return image.resize(tuple(size), resampling.NEAREST)
@click.command()
@click.option("--checkpoint", default="decoder.pth", help="Path to the decoder state dict.")
@click.option("--config", default="config.json", help="Path to the decoder training config (JSON).")
@click.option(
    "--embedding-folder",
    default="prompt_embeddings",
    help="Folder containing prompt_XXX/prompt_XXX_batch.npy embedding files.",
)
@click.option("--prompt", default=0, help="Index of the first prompt to sample.")
@click.option("--cond-scale", default=5.0, help="Conditioning scale forwarded to decoder.sample.")
@click.option("--num-prompts", default=5, help="Number of consecutive prompts to sample.")
def main(checkpoint, config, embedding_folder, prompt, cond_scale, num_prompts=5):
    """Sample images from a decoder for a run of pre-computed prompt embeddings.

    For each prompt index in ``[prompt, prompt + num_prompts)``, loads
    ``<embedding_folder>/prompt_XXX/prompt_XXX_batch.npy``, samples images
    conditioned on those embeddings, and writes into
    ``<embedding_folder>/prompt_XXX/predictions/``:

    * ``grid_raw.png``      -- all samples tiled at native resolution,
    * ``grid_resized.png``  -- the same grid resized to 512x512,
    * ``prompt_NNN_out.png`` -- each sample individually (NNN is the sample's
      position in the batch, not the prompt index).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    decoder = load_decoder(checkpoint, config, device)

    for prompt_index in range(prompt, prompt + num_prompts):
        prompt_name = f"prompt_{prompt_index:03d}"
        prompt_dir = os.path.join(embedding_folder, prompt_name)

        # Load embeddings; cast to float32 so files saved as float16/float64
        # match the decoder's parameters (presumably float32 -- TODO confirm).
        embeddings = np.load(os.path.join(prompt_dir, f"{prompt_name}_batch.npy"))
        embeddings = torch.from_numpy(embeddings).to(device=device, dtype=torch.float32)

        # Sample. No gradients are needed for inference.
        with torch.no_grad():
            images = decoder.sample(image_embed=embeddings, cond_scale=cond_scale)
        # to_pil_image maps floats via x * 255 -> uint8, so clamp once here so
        # the grids and the individual images are all saved from in-range values.
        images = images.clamp(0, 1)

        # Make grids
        actual_grid = format_images(batch=images, resize=False)
        resized_grid = format_images(batch=images, resize=True)

        # Save
        save_path = os.path.join(prompt_dir, "predictions")
        os.makedirs(save_path, exist_ok=True)
        actual_grid.save(os.path.join(save_path, "grid_raw.png"))
        resized_grid.save(os.path.join(save_path, "grid_resized.png"))
        for idx, img in enumerate(images):
            to_pil_image(img).save(os.path.join(save_path, f"prompt_{idx:03d}_out.png"))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment