Created
August 29, 2022 14:32
-
-
Save td2sk/32c81ffbe86c6de1a1bea796939b4410 to your computer and use it in GitHub Desktop.
Jupyter notebook for Stable Diffusion txt2img
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from scripts.txt2img_for_notebook import load_model, main" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = load_model()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"prompt = \"pyramid\"\n", | |
"prompt_correction =[\"egypt::-0.5\", \"japan::0.5\"]\n", | |
"seed = 42\n", | |
"\n", | |
"images = main(model, prompt, seed=seed, plms=True, prompt_correction=prompt_correction)\n", | |
"images[0]" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3.8.5 ('ldm')", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.13" | |
}, | |
"orig_nbformat": 4, | |
"vscode": { | |
"interpreter": { | |
"hash": "9b1f5017147893acf565d1e383c3b688375657195e3a8ca3aea8d088cd976bc0" | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import glob | |
import os | |
import sys | |
import time | |
from contextlib import contextmanager, nullcontext | |
from itertools import islice | |
import numpy as np | |
import torch | |
from einops import rearrange | |
from ldm.models.diffusion.ddim import DDIMSampler | |
from ldm.models.diffusion.plms import PLMSSampler | |
from ldm.util import instantiate_from_config | |
from omegaconf import OmegaConf | |
from PIL import Image | |
from pytorch_lightning import seed_everything | |
from torch import autocast | |
from torchvision.utils import make_grid | |
from tqdm.auto import tqdm, trange | |
def chunk(it, size):
    """Yield consecutive tuples of at most ``size`` items taken from ``it``.

    The input is consumed lazily. The last tuple may be shorter than
    ``size``; iteration ends as soon as an empty batch is produced.
    """
    source = iter(it)  # obtain the iterator up front so non-iterables fail at call time

    def _batches():
        while True:
            batch = tuple(islice(source, size))
            if not batch:
                return
            yield batch

    return _batches()
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: Configuration object whose ``model`` attribute is handed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint whose top-level mapping contains a
            ``"state_dict"`` entry.
        verbose: When true, print the keys that were missing from, or
            unexpected in, the checkpoint's state dict.

    Returns:
        The instantiated model, moved to CUDA and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    # Deserialize onto the CPU first; the model is moved to the GPU below.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches, so surface them on request
    # instead of silently discarding the report.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose:
        if missing:
            print("missing keys:")
            print(missing)
        if unexpected:
            print("unexpected keys:")
            print(unexpected)
    model.cuda()
    model.eval()
    return model
def load_model(laion400m=False, config="configs/stable-diffusion/v1-inference.yaml", ckpt="models/ldm/stable-diffusion-v1/model.ckpt"):
    """Load the configured checkpoint and return the model cast to float16.

    Args:
        laion400m: When true, ignore ``config`` and ``ckpt`` and use the
            LAION 400M configuration and checkpoint paths instead.
        config: Path of the model configuration file to load.
        ckpt: Path of the checkpoint file to load.

    Returns:
        The model produced by ``load_model_from_config``, converted to
        ``torch.float16``.
    """
    config_path, ckpt_path = config, ckpt
    if laion400m:
        print("Falling back to LAION 400M model...")
        config_path = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
        ckpt_path = "models/ldm/text2img-large/model.ckpt"

    parsed_config = OmegaConf.load(f"{config_path}")
    loaded = load_model_from_config(parsed_config, f"{ckpt_path}")
    # Half precision keeps the model usable on GPUs with 8 GB of VRAM.
    return loaded.to(torch.float16)
def _apply_prompt_correction(model, c, prompt_correction):
    """Add each weighted correction's conditioning to ``c`` in place and return it."""
    for entry in prompt_correction:
        # Everything before the final '::' is text; the last field is the weight.
        *texts, weight_text = entry.split('::')
        if not texts:
            raise ValueError(
                f"prompt_correction entry {entry!r} must look like 'text::weight'")
        weight = float(weight_text)
        c += weight * model.get_learned_conditioning(texts)
    return c


def _to_pil_images(x_samples):
    """Map decoded samples from [-1, 1] to 8-bit pixels and wrap them as PIL images."""
    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
    pil_images = []
    for x_sample in x_samples:
        # Channel-first tensor -> channel-last array, as Image.fromarray expects.
        pixels = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
        pil_images.append(Image.fromarray(pixels.astype(np.uint8)))
    return pil_images


def main(model, prompt: str, ddim_steps=50, fixed_code=False, plms=False, ddim_eta=0.0, n_iter=1, H=512, W=512, C=4, f=8, n_samples=1, n_rows=0, scale=7.5, seed=42, precision="autocast", prompt_correction=()):
    """Sample images for ``prompt`` and return them as a list of PIL images.

    Args:
        model: Model providing ``get_learned_conditioning``, ``ema_scope`` and
            ``decode_first_stage``; it is moved to CUDA when available,
            otherwise to the CPU.
        prompt: Text prompt; it is repeated ``n_samples`` times per batch.
        ddim_steps: Number of sampler steps (passed as ``S``).
        fixed_code: When true, draw one starting noise tensor up front and
            reuse it for every iteration.
        plms: Use ``PLMSSampler`` when true, ``DDIMSampler`` otherwise.
        ddim_eta: Passed to the sampler as ``eta``.
        n_iter: Number of sampling iterations.
        H, W: Image height and width; the latent size is ``H // f`` by ``W // f``.
        C: Number of latent channels.
        f: Downsampling factor between image and latent size.
        n_samples: Batch size, i.e. images produced per iteration.
        n_rows: Accepted for backward compatibility; not used.
        scale: Unconditional guidance scale; when it is exactly 1.0 no
            unconditional conditioning is computed.
        seed: Seed passed to ``seed_everything``.
        precision: ``"autocast"`` runs under ``torch.autocast("cuda")``; any
            other value disables autocast.
        prompt_correction: Iterable of ``"text::weight"`` strings. The
            conditioning of each text, multiplied by its weight, is added to
            the prompt conditioning.

    Returns:
        A list with ``n_iter * n_samples`` PIL images.

    Raises:
        ValueError: If a ``prompt_correction`` entry has no ``::`` separator
            or its weight is not a number.
    """
    seed_everything(seed)
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    sampler = PLMSSampler(model) if plms else DDIMSampler(model)

    batch_size = n_samples
    assert prompt is not None
    data = [batch_size * [prompt]]
    shape = [C, H // f, W // f]

    start_code = None
    if fixed_code:
        start_code = torch.randn([n_samples, *shape], device=device)

    images = []
    precision_scope = autocast if precision == "autocast" else nullcontext
    with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
        # The empty-prompt conditioning does not depend on the loop state,
        # so compute it once instead of on every iteration.
        uc = None
        if scale != 1.0:
            uc = model.get_learned_conditioning(batch_size * [""])

        for _ in trange(n_iter, desc="Sampling"):
            for prompts in tqdm(data):
                if isinstance(prompts, tuple):
                    prompts = list(prompts)
                c = model.get_learned_conditioning(prompts)
                c = _apply_prompt_correction(model, c, prompt_correction)
                samples_ddim, _intermediates = sampler.sample(
                    S=ddim_steps,
                    conditioning=c,
                    batch_size=n_samples,
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=scale,
                    unconditional_conditioning=uc,
                    eta=ddim_eta,
                    x_T=start_code)
                x_samples_ddim = model.decode_first_stage(samples_ddim)
                images.extend(_to_pil_images(x_samples_ddim))
    return images
if __name__ == "__main__":
    # main() requires a loaded model and a prompt, so the previous bare
    # ``main()`` call always failed with a TypeError. Provide a minimal CLI.
    parser = argparse.ArgumentParser(
        description="Generate images from a text prompt.")
    parser.add_argument("prompt", help="text prompt to render")
    parser.add_argument("--seed", type=int, default=42,
                        help="random seed (default: 42)")
    parser.add_argument("--plms", action="store_true",
                        help="use the PLMS sampler instead of DDIM")
    parser.add_argument("--outdir", default="outputs",
                        help="directory the PNG files are written to")
    args = parser.parse_args()

    os.makedirs(args.outdir, exist_ok=True)
    generated = main(load_model(), args.prompt, seed=args.seed, plms=args.plms)
    for index, image in enumerate(generated):
        image.save(os.path.join(args.outdir, f"{index:05}.png"))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment