Created July 4, 2022 18:28
-
-
Save nousr/452b8cec7326c8587a03aa6876e9d148 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import warnings

import click
import clip
import numpy as np
import torch
import wandb
from PIL import Image
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid

from dalle2_pytorch import (
    DALLE2,
    DiffusionPrior,
    DiffusionPriorNetwork,
    OpenAIClipAdapter,
    train_configs,
)
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.trainer import DiffusionPriorTrainer
def load_decoder(decoder_state_dict_path, config_file_path, device):
    """Build a decoder from a JSON training config and load its weights.

    Args:
        decoder_state_dict_path: Path to a file readable by ``torch.load``
            that holds the decoder's ``state_dict``.
        config_file_path: Path to a ``TrainDecoderConfig`` JSON file; its
            ``decoder`` section is used to construct the model.
        device: Device the decoder is moved to.

    Returns:
        The decoder, switched to eval mode, on ``device``.

    The state dict is still loaded non-strictly, but any missing or
    unexpected keys are reported with a warning instead of being silently
    ignored, so a checkpoint that does not match the config is visible.
    """
    config = train_configs.TrainDecoderConfig.from_json_path(config_file_path)
    decoder = config.decoder.create().to(device)

    # Deserialize on CPU; load_state_dict then copies the tensors into the
    # decoder's already device-resident parameters.
    decoder_state_dict = torch.load(decoder_state_dict_path, map_location="cpu")
    load_result = decoder.load_state_dict(decoder_state_dict, strict=False)
    del decoder_state_dict  # drop the CPU copy before sampling

    # torch.nn.Module.load_state_dict returns (missing_keys, unexpected_keys);
    # read them defensively in case a subclass returns something else.
    missing = list(getattr(load_result, "missing_keys", []) or [])
    unexpected = list(getattr(load_result, "unexpected_keys", []) or [])
    if missing or unexpected:
        warnings.warn(
            f"Decoder checkpoint {decoder_state_dict_path!r} does not exactly "
            f"match config {config_file_path!r}: {len(missing)} missing key(s) "
            f"{missing[:5]}, {len(unexpected)} unexpected key(s) {unexpected[:5]}",
            stacklevel=2,
        )

    decoder.eval()
    return decoder
def format_images(
    batch: torch.Tensor,
    nrow: int = 3,
    resize: bool = False,
    *,
    size: "tuple[int, int]" = (512, 512),
):
    """Tile a batch of images into a single PIL image grid.

    Args:
        batch: Image tensor accepted by ``torchvision.utils.make_grid``;
            floating-point values are expected to lie in [0, 1].
        nrow: Number of images per grid row.
        resize: If truthy, resize the grid to ``size`` with nearest-neighbour
            resampling.
        size: Target ``(width, height)`` used when ``resize`` is truthy.

    Returns:
        A ``PIL.Image.Image`` containing the grid (always an image, never
        ``None``, whatever value ``resize`` has).
    """
    grid = make_grid(batch, nrow=nrow)
    if grid.is_floating_point():
        # to_pil_image scales floats by 255 and casts to uint8, so values
        # outside [0, 1] would wrap around; clamp first.
        grid = grid.clamp(0, 1)
    image = to_pil_image(grid)
    if not resize:
        return image
    # Image.Resampling only exists in Pillow >= 9.1; older versions expose
    # NEAREST directly on the Image module.
    resampling = getattr(Image, "Resampling", Image)
    return image.resize(tuple(size), resampling.NEAREST)
@click.command()
@click.option("--checkpoint", default="decoder.pth")
@click.option("--config", default="config.json")
@click.option("--embedding-folder", default="prompt_embeddings")
@click.option("--prompt", default=0)
@click.option("--cond-scale", default=5.0)
@click.option(
    "--num-prompts",
    default=5,
    show_default=True,
    help="How many consecutive prompt folders to process, starting at --prompt.",
)
def main(checkpoint, config, embedding_folder, prompt, cond_scale, num_prompts=5):
    """Sample decoder outputs for a run of pre-computed prompt embeddings.

    For each index ``k`` in ``[prompt, prompt + num_prompts)``, load
    ``<embedding_folder>/prompt_<k>/prompt_<k>_batch.npy``, run
    ``decoder.sample`` on it, and write the results to
    ``<embedding_folder>/prompt_<k>/predictions/``:
    ``grid_raw.png``, ``grid_resized.png``, and one
    ``prompt_<i>_out.png`` per sample ``i`` in the batch.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load checkpoint
    decoder = load_decoder(checkpoint, config, device)

    for prompt_idx in range(prompt, prompt + num_prompts):
        prompt_name = f"prompt_{prompt_idx:03d}"
        prompt_dir = os.path.join(embedding_folder, prompt_name)

        # Load embeddings. Cast to float32 so they match the decoder's
        # weights even if the .npy file was saved as float16/float64.
        embeddings = np.load(os.path.join(prompt_dir, f"{prompt_name}_batch.npy"))
        embeddings = torch.from_numpy(embeddings).to(device=device, dtype=torch.float32)

        # Sample (inference only, so no autograd graph is needed)
        with torch.no_grad():
            images = decoder.sample(image_embed=embeddings, cond_scale=cond_scale)

        # Make grids and save
        save_path = os.path.join(prompt_dir, "predictions")
        os.makedirs(save_path, exist_ok=True)
        format_images(batch=images, resize=False).save(
            os.path.join(save_path, "grid_raw.png")
        )
        format_images(batch=images, resize=True).save(
            os.path.join(save_path, "grid_resized.png")
        )

        # Individual samples; idx is the position within this prompt's batch.
        for idx, img in enumerate(images):
            to_pil_image(img.clamp(0, 1)).save(
                os.path.join(save_path, f"prompt_{idx:03d}_out.png")
            )
# Script entry point: `main` is a click command, so calling it parses the
# command-line options declared on it.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.